diff --git a/.github/workflows/java-ci.yml b/.github/workflows/java-ci.yml index 1da7a673a865..bb99e95a12a1 100644 --- a/.github/workflows/java-ci.yml +++ b/.github/workflows/java-ci.yml @@ -95,7 +95,7 @@ jobs: runs-on: ubuntu-22.04 strategy: matrix: - jvm: [11, 17, 21] + jvm: [17, 21] steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 @@ -108,7 +108,7 @@ jobs: runs-on: ubuntu-22.04 strategy: matrix: - jvm: [11, 17, 21] + jvm: [17, 21] steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 diff --git a/.github/workflows/jmh-benchmarks.yml b/.github/workflows/jmh-benchmarks.yml index cfb53513e743..6c0b10078053 100644 --- a/.github/workflows/jmh-benchmarks.yml +++ b/.github/workflows/jmh-benchmarks.yml @@ -28,8 +28,8 @@ on: description: 'The branch name' required: true spark_version: - description: 'The spark project version to use, such as iceberg-spark-3.5' - default: 'iceberg-spark-3.5' + description: 'The spark project version to use, such as iceberg-spark-4.0' + default: 'iceberg-spark-4.0' required: true benchmarks: description: 'A list of comma-separated double-quoted Benchmark names, such as "IcebergSourceFlatParquetDataReadBenchmark", "IcebergSourceFlatParquetDataFilterBenchmark"' diff --git a/.github/workflows/publish-snapshot.yml b/.github/workflows/publish-snapshot.yml index 7ff6b56da576..62cd9cd38706 100644 --- a/.github/workflows/publish-snapshot.yml +++ b/.github/workflows/publish-snapshot.yml @@ -41,4 +41,4 @@ jobs: - run: | ./gradlew printVersion ./gradlew -DallModules publishApachePublicationToMavenRepository -PmavenUser=${{ secrets.NEXUS_USER }} -PmavenPassword=${{ secrets.NEXUS_PW }} - ./gradlew -DflinkVersions= -DsparkVersions=3.3,3.4,3.5 -DscalaVersion=2.13 -DkafkaVersions=3 -DhiveVersions= publishApachePublicationToMavenRepository -PmavenUser=${{ secrets.NEXUS_USER }} -PmavenPassword=${{ secrets.NEXUS_PW }} + ./gradlew -DflinkVersions= -DsparkVersions=3.3,3.4,3.5,4.0 -DscalaVersion=2.13 -DkafkaVersions=3 -DhiveVersions= publishApachePublicationToMavenRepository -PmavenUser=${{ secrets.NEXUS_USER }} -PmavenPassword=${{ secrets.NEXUS_PW }} diff --git a/.github/workflows/recurring-jmh-benchmarks.yml b/.github/workflows/recurring-jmh-benchmarks.yml index 71a52640b2f6..eec1a4aac128 100644 --- a/.github/workflows/recurring-jmh-benchmarks.yml +++ b/.github/workflows/recurring-jmh-benchmarks.yml @@ -41,7 +41,7 @@ jobs: "IcebergSourceNestedParquetDataReadBenchmark", "IcebergSourceNestedParquetDataWriteBenchmark", "IcebergSourceParquetEqDeleteBenchmark", "IcebergSourceParquetMultiDeleteFileBenchmark", "IcebergSourceParquetPosDeleteBenchmark", "IcebergSourceParquetWithUnrelatedDeleteBenchmark"] - spark_version: ['iceberg-spark-3.5'] + spark_version: ['iceberg-spark-4.0'] env: SPARK_LOCAL_IP: localhost steps: diff --git a/.github/workflows/spark-ci.yml b/.github/workflows/spark-ci.yml index 0d7bd2d3d3e7..d3c36cd54f1a 100644 --- a/.github/workflows/spark-ci.yml +++ b/.github/workflows/spark-ci.yml @@ -73,15 +73,19 @@ jobs: strategy: matrix: jvm: [11, 17, 21] - spark: ['3.3', '3.4', '3.5'] - scala: ['2.12', '2.13'] + spark: [ '3.3', '3.4', '3.5', '4.0' ] + scala: [ '2.12', '2.13' ] exclude: # Spark 3.5 is the first version not failing on Java 21 (https://issues.apache.org/jira/browse/SPARK-42369) # Full Java 21 support is coming in Spark 4 (https://issues.apache.org/jira/browse/SPARK-43831) + - jvm: 11 + spark: '4.0' - jvm: 21 spark: '3.3' - jvm: 21 spark: '3.4' + - spark: '4.0' + scala: '2.12' env: SPARK_LOCAL_IP: localhost steps: diff --git a/.gitignore 
b/.gitignore index e4c9e1a16a27..26377425f1e0 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,8 @@ spark/v3.4/spark/benchmark/* spark/v3.4/spark-extensions/benchmark/* spark/v3.5/spark/benchmark/* spark/v3.5/spark-extensions/benchmark/* +spark/v4.0/spark/benchmark/* +spark/v4.0/spark-extensions/benchmark/* */benchmark/* __pycache__/ diff --git a/build.gradle b/build.gradle index 81daf14a357f..09c85cbae045 100644 --- a/build.gradle +++ b/build.gradle @@ -119,6 +119,9 @@ allprojects { repositories { mavenCentral() mavenLocal() + maven { + url "https://repository.apache.org/content/repositories/snapshots/" + } } } diff --git a/gradle.properties b/gradle.properties index dc1e1a509b01..b7aff6b17169 100644 --- a/gradle.properties +++ b/gradle.properties @@ -20,8 +20,8 @@ systemProp.defaultFlinkVersions=1.20 systemProp.knownFlinkVersions=1.18,1.19,1.20 systemProp.defaultHiveVersions=2 systemProp.knownHiveVersions=2,3 -systemProp.defaultSparkVersions=3.5 -systemProp.knownSparkVersions=3.3,3.4,3.5 +systemProp.defaultSparkVersions=4.0 +systemProp.knownSparkVersions=3.3,3.4,3.5,4.0 systemProp.defaultKafkaVersions=3 systemProp.knownKafkaVersions=3 systemProp.defaultScalaVersion=2.12 diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 8af0d6ec6ab2..8d5f202180a8 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -23,6 +23,7 @@ activation = "1.1.1" aliyun-sdk-oss = "3.10.2" antlr = "4.9.3" +antlr413 = "4.13.1" # For Spark 4.0 support aircompressor = "0.27" apiguardian = "1.1.2" arrow = "15.0.2" @@ -48,6 +49,7 @@ google-libraries-bom = "26.50.0" guava = "33.3.1-jre" hadoop2 = "2.7.3" hadoop3 = "3.4.1" +hadoop34 = "3.4.0" # For Spark 4.0 support httpcomponents-httpclient5 = "5.4.1" hive2 = { strictly = "2.3.9"} # see rich version usage explanation above hive3 = "3.1.3" @@ -83,6 +85,7 @@ snowflake-jdbc = "3.20.0" spark-hive33 = "3.3.4" spark-hive34 = "3.4.4" spark-hive35 = "3.5.2" +spark-hive40 = "4.0.0-SNAPSHOT" sqlite-jdbc = "3.47.0.0" testcontainers = "1.20.3" tez010 = "0.10.4" @@ -94,6 +97,8 @@ aircompressor = { module = "io.airlift:aircompressor", version.ref = "aircompres aliyun-sdk-oss = { module = "com.aliyun.oss:aliyun-sdk-oss", version.ref = "aliyun-sdk-oss" } antlr-antlr4 = { module = "org.antlr:antlr4", version.ref = "antlr" } antlr-runtime = { module = "org.antlr:antlr4-runtime", version.ref = "antlr" } +antlr-antlr413 = { module = "org.antlr:antlr4", version.ref = "antlr413" } +antlr-runtime413 = { module = "org.antlr:antlr4-runtime", version.ref = "antlr413" } arrow-memory-netty = { module = "org.apache.arrow:arrow-memory-netty", version.ref = "arrow" } arrow-vector = { module = "org.apache.arrow:arrow-vector", version.ref = "arrow" } avro-avro = { module = "org.apache.avro:avro", version.ref = "avro" } @@ -135,6 +140,7 @@ hadoop2-mapreduce-client-core = { module = "org.apache.hadoop:hadoop-mapreduce-c hadoop2-minicluster = { module = "org.apache.hadoop:hadoop-minicluster", version.ref = "hadoop2" } hadoop3-client = { module = "org.apache.hadoop:hadoop-client", version.ref = "hadoop3" } hadoop3-common = { module = "org.apache.hadoop:hadoop-common", version.ref = "hadoop3" } +hadoop34-minicluster = { module = "org.apache.hadoop:hadoop-minicluster", version.ref = "hadoop34" } hive2-exec = { module = "org.apache.hive:hive-exec", version.ref = "hive2" } hive2-metastore = { module = "org.apache.hive:hive-metastore", version.ref = "hive2" } hive2-serde = { module = "org.apache.hive:hive-serde", version.ref = "hive2" } diff --git a/jmh.gradle 
b/jmh.gradle index a5d8d624270d..ac93771d0d03 100644 --- a/jmh.gradle +++ b/jmh.gradle @@ -53,6 +53,11 @@ if (sparkVersions.contains("3.5")) { jmhProjects.add(project(":iceberg-spark:iceberg-spark-extensions-3.5_${scalaVersion}")) } +if (sparkVersions.contains("4.0")) { + jmhProjects.add(project(":iceberg-spark:iceberg-spark-4.0_2.13")) + jmhProjects.add(project(":iceberg-spark:iceberg-spark-extensions-4.0_2.13")) +} + configure(jmhProjects) { apply plugin: 'me.champeau.jmh' apply plugin: 'io.morethan.jmhreport' diff --git a/settings.gradle b/settings.gradle index 103741389a26..66ff30315e42 100644 --- a/settings.gradle +++ b/settings.gradle @@ -180,6 +180,18 @@ if (sparkVersions.contains("3.5")) { project(":iceberg-spark:spark-runtime-3.5_${scalaVersion}").name = "iceberg-spark-runtime-3.5_${scalaVersion}" } +if (sparkVersions.contains("4.0")) { + include ":iceberg-spark:spark-4.0_2.13" + include ":iceberg-spark:spark-extensions-4.0_2.13" + include ":iceberg-spark:spark-runtime-4.0_2.13" + project(":iceberg-spark:spark-4.0_2.13").projectDir = file('spark/v4.0/spark') + project(":iceberg-spark:spark-4.0_2.13").name = "iceberg-spark-4.0_2.13" + project(":iceberg-spark:spark-extensions-4.0_2.13").projectDir = file('spark/v4.0/spark-extensions') + project(":iceberg-spark:spark-extensions-4.0_2.13").name = "iceberg-spark-extensions-4.0_2.13" + project(":iceberg-spark:spark-runtime-4.0_2.13").projectDir = file('spark/v4.0/spark-runtime') + project(":iceberg-spark:spark-runtime-4.0_2.13").name = "iceberg-spark-runtime-4.0_2.13" +} + // hive 3 depends on hive 2, so always add hive 2 if hive3 is enabled if (hiveVersions.contains("2") || hiveVersions.contains("3")) { include 'mr' diff --git a/spark/build.gradle b/spark/build.gradle index c2bc5f8a14ed..160892432a01 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -31,3 +31,8 @@ if (sparkVersions.contains("3.4")) { if (sparkVersions.contains("3.5")) { apply from: file("$projectDir/v3.5/build.gradle") } + + +if (sparkVersions.contains("4.0")) { + apply from: file("$projectDir/v4.0/build.gradle") +} \ No newline at end of file diff --git a/spark/v4.0/build.gradle b/spark/v4.0/build.gradle new file mode 100644 index 000000000000..c5ae5da0be79 --- /dev/null +++ b/spark/v4.0/build.gradle @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +String sparkMajorVersion = '4.0' +String scalaVersion = '2.13' + +def sparkProjects = [ + project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}"), + project(":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}"), + project(":iceberg-spark:iceberg-spark-runtime-${sparkMajorVersion}_${scalaVersion}"), +] + +configure(sparkProjects) { + configurations { + all { + resolutionStrategy { + force "com.fasterxml.jackson.module:jackson-module-scala_${scalaVersion}:${libs.versions.jackson215.get()}" + force "com.fasterxml.jackson.core:jackson-databind:${libs.versions.jackson215.get()}" + force "com.fasterxml.jackson.core:jackson-core:${libs.versions.jackson215.get()}" + } + } + } +} + +project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}") { + apply plugin: 'scala' + apply plugin: 'com.github.alisiikh.scalastyle' + + sourceSets { + main { + scala.srcDirs = ['src/main/scala', 'src/main/java'] + java.srcDirs = [] + } + } + + dependencies { + implementation project(path: ':iceberg-bundled-guava', configuration: 'shadow') + api project(':iceberg-api') + implementation project(':iceberg-common') + implementation project(':iceberg-core') + implementation project(':iceberg-data') + implementation project(':iceberg-orc') + implementation project(':iceberg-parquet') + implementation project(':iceberg-arrow') + implementation("org.scala-lang.modules:scala-collection-compat_${scalaVersion}:${libs.versions.scala.collection.compat.get()}") + implementation("org.apache.datasketches:datasketches-java:${libs.versions.datasketches.get()}") + if (scalaVersion == '2.12') { + // scala-collection-compat_2.12 pulls scala 2.12.17 and we need 2.12.18 for JDK 21 support + implementation 'org.scala-lang:scala-library:2.12.18' + } + + compileOnly libs.errorprone.annotations + compileOnly libs.avro.avro + compileOnly("org.apache.spark:spark-hive_${scalaVersion}:${libs.versions.spark.hive40.get()}") { + exclude group: 'org.apache.avro', module: 'avro' + exclude group: 'org.apache.arrow' + exclude group: 'org.apache.parquet' + // to make sure netty libs only come from project(':iceberg-arrow') + exclude group: 'io.netty', module: 'netty-buffer' + exclude group: 'io.netty', module: 'netty-common' + exclude group: 'org.roaringbitmap' + } + + implementation libs.parquet.column + implementation libs.parquet.hadoop + + implementation("${libs.orc.core.get().module}:${libs.versions.orc.get()}:nohive") { + exclude group: 'org.apache.hadoop' + exclude group: 'commons-lang' + // These artifacts are shaded and included in the orc-core fat jar + exclude group: 'com.google.protobuf', module: 'protobuf-java' + exclude group: 'org.apache.hive', module: 'hive-storage-api' + } + + implementation(libs.arrow.vector) { + exclude group: 'io.netty', module: 'netty-buffer' + exclude group: 'io.netty', module: 'netty-common' + exclude group: 'com.google.code.findbugs', module: 'jsr305' + } + + implementation libs.caffeine + + // Add BoneCP dependency for test configurations + testImplementation 'com.jolbox:bonecp:0.8.0.RELEASE' + + testImplementation(libs.hadoop34.minicluster) { + exclude group: 'org.apache.avro', module: 'avro' + // to make sure netty libs only come from project(':iceberg-arrow') + exclude group: 'io.netty', module: 'netty-buffer' + exclude group: 'io.netty', module: 'netty-common' + } + testImplementation project(path: ':iceberg-hive-metastore') + testImplementation project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts') + testImplementation 
project(path: ':iceberg-api', configuration: 'testArtifacts') + testImplementation project(path: ':iceberg-core', configuration: 'testArtifacts') + testImplementation project(path: ':iceberg-data', configuration: 'testArtifacts') + testImplementation libs.sqlite.jdbc + testImplementation libs.awaitility + } + + test { + useJUnitPlatform() + } + + tasks.withType(Test) { + // Vectorized reads need more memory + maxHeapSize '2560m' + } +} + +project(":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}") { + apply plugin: 'java-library' + apply plugin: 'scala' + apply plugin: 'com.github.alisiikh.scalastyle' + apply plugin: 'antlr' + + configurations { + /* + The Gradle Antlr plugin erroneously adds both antlr-build and runtime dependencies to the runtime path. This + bug https://github.com/gradle/gradle/issues/820 exists because older versions of Antlr do not have separate + runtime and implementation dependencies and they do not want to break backwards compatibility. So to only end up with + the runtime dependency on the runtime classpath we remove the dependencies added by the plugin here. Then add + the runtime dependency back to only the runtime configuration manually. + */ + implementation { + extendsFrom = extendsFrom.findAll { it != configurations.antlr } + } + } + + dependencies { + implementation("org.scala-lang.modules:scala-collection-compat_${scalaVersion}:${libs.versions.scala.collection.compat.get()}") + if (scalaVersion == '2.12') { + // scala-collection-compat_2.12 pulls scala 2.12.17 and we need 2.12.18 for JDK 21 support + implementation 'org.scala-lang:scala-library:2.12.18' + } + implementation libs.roaringbitmap + + compileOnly "org.scala-lang:scala-library" + compileOnly project(path: ':iceberg-bundled-guava', configuration: 'shadow') + compileOnly project(':iceberg-api') + compileOnly project(':iceberg-core') + compileOnly project(':iceberg-common') + compileOnly project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}") + compileOnly("org.apache.spark:spark-hive_${scalaVersion}:${libs.versions.spark.hive40.get()}") { + exclude group: 'org.apache.avro', module: 'avro' + exclude group: 'org.apache.arrow' + exclude group: 'org.apache.parquet' + // to make sure netty libs only come from project(':iceberg-arrow') + exclude group: 'io.netty', module: 'netty-buffer' + exclude group: 'io.netty', module: 'netty-common' + exclude group: 'org.roaringbitmap' + } + compileOnly libs.errorprone.annotations + + testImplementation project(path: ':iceberg-data') + testImplementation project(path: ':iceberg-parquet') + testImplementation project(path: ':iceberg-hive-metastore') + testImplementation project(path: ':iceberg-api', configuration: 'testArtifacts') + testImplementation project(path: ':iceberg-core', configuration: 'testArtifacts') + testImplementation project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts') + testImplementation project(path: ":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}", configuration: 'testArtifacts') + // Add BoneCP dependency for test configurations + testImplementation 'com.jolbox:bonecp:0.8.0.RELEASE' + + testImplementation libs.avro.avro + testImplementation libs.parquet.hadoop + testImplementation libs.awaitility + + // Required because we remove antlr plugin dependencies from the compile configuration, see note above + runtimeOnly libs.antlr.runtime413 + antlr libs.antlr.antlr413 + } + + test { + useJUnitPlatform() + } + + generateGrammarSource { + maxHeapSize = "64m" + 
arguments += ['-visitor', '-package', 'org.apache.spark.sql.catalyst.parser.extensions'] + } +} + +project(":iceberg-spark:iceberg-spark-runtime-${sparkMajorVersion}_${scalaVersion}") { + apply plugin: 'com.gradleup.shadow' + + tasks.jar.dependsOn tasks.shadowJar + + sourceSets { + integration { + java.srcDir "$projectDir/src/integration/java" + resources.srcDir "$projectDir/src/integration/resources" + } + } + + configurations { + implementation { + exclude group: 'org.apache.spark' + // included in Spark + exclude group: 'org.slf4j' + exclude group: 'org.apache.commons' + exclude group: 'commons-pool' + exclude group: 'commons-codec' + exclude group: 'org.xerial.snappy' + exclude group: 'javax.xml.bind' + exclude group: 'javax.annotation' + exclude group: 'com.github.luben' + exclude group: 'com.ibm.icu' + exclude group: 'org.glassfish' + exclude group: 'org.abego.treelayout' + exclude group: 'org.antlr' + exclude group: 'org.scala-lang' + exclude group: 'org.scala-lang.modules' + } + } + + dependencies { + api project(':iceberg-api') + implementation project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}") + implementation project(":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}") + implementation project(':iceberg-aws') + implementation project(':iceberg-azure') + implementation(project(':iceberg-aliyun')) { + exclude group: 'edu.umd.cs.findbugs', module: 'findbugs' + exclude group: 'org.apache.httpcomponents', module: 'httpclient' + exclude group: 'commons-logging', module: 'commons-logging' + } + implementation project(':iceberg-gcp') + implementation project(':iceberg-hive-metastore') + implementation(project(':iceberg-nessie')) { + exclude group: 'com.google.code.findbugs', module: 'jsr305' + } + implementation (project(':iceberg-snowflake')) { + exclude group: 'net.snowflake' , module: 'snowflake-jdbc' + } + + integrationImplementation "org.scala-lang.modules:scala-collection-compat_${scalaVersion}:${libs.versions.scala.collection.compat.get()}" + integrationImplementation "org.apache.spark:spark-hive_${scalaVersion}:${libs.versions.spark.hive40.get()}" + integrationImplementation libs.junit.vintage.engine + integrationImplementation libs.junit.jupiter + integrationImplementation libs.slf4j.simple + integrationImplementation libs.assertj.core + integrationImplementation project(path: ':iceberg-api', configuration: 'testArtifacts') + integrationImplementation project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts') + integrationImplementation project(path: ":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}", configuration: 'testArtifacts') + integrationImplementation project(path: ":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}", configuration: 'testArtifacts') + // Not allowed on our classpath, only the runtime jar is allowed + integrationCompileOnly project(":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}") + integrationCompileOnly project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}") + integrationCompileOnly project(':iceberg-api') + } + + shadowJar { + configurations = [project.configurations.runtimeClasspath] + + zip64 true + + // include the LICENSE and NOTICE files for the shaded Jar + from(projectDir) { + include 'LICENSE' + include 'NOTICE' + } + + // Relocate dependencies to avoid conflicts + relocate 'com.google.errorprone', 'org.apache.iceberg.shaded.com.google.errorprone' + relocate 'com.google.flatbuffers', 
'org.apache.iceberg.shaded.com.google.flatbuffers' + relocate 'com.fasterxml', 'org.apache.iceberg.shaded.com.fasterxml' + relocate 'com.github.benmanes', 'org.apache.iceberg.shaded.com.github.benmanes' + relocate 'org.checkerframework', 'org.apache.iceberg.shaded.org.checkerframework' + relocate 'org.apache.avro', 'org.apache.iceberg.shaded.org.apache.avro' + relocate 'avro.shaded', 'org.apache.iceberg.shaded.org.apache.avro.shaded' + relocate 'com.thoughtworks.paranamer', 'org.apache.iceberg.shaded.com.thoughtworks.paranamer' + relocate 'org.apache.parquet', 'org.apache.iceberg.shaded.org.apache.parquet' + relocate 'shaded.parquet', 'org.apache.iceberg.shaded.org.apache.parquet.shaded' + relocate 'org.apache.orc', 'org.apache.iceberg.shaded.org.apache.orc' + relocate 'io.airlift', 'org.apache.iceberg.shaded.io.airlift' + relocate 'org.apache.hc.client5', 'org.apache.iceberg.shaded.org.apache.hc.client5' + relocate 'org.apache.hc.core5', 'org.apache.iceberg.shaded.org.apache.hc.core5' + // relocate Arrow and related deps to shade Iceberg specific version + relocate 'io.netty', 'org.apache.iceberg.shaded.io.netty' + relocate 'org.apache.arrow', 'org.apache.iceberg.shaded.org.apache.arrow' + relocate 'com.carrotsearch', 'org.apache.iceberg.shaded.com.carrotsearch' + relocate 'org.threeten.extra', 'org.apache.iceberg.shaded.org.threeten.extra' + relocate 'org.roaringbitmap', 'org.apache.iceberg.shaded.org.roaringbitmap' + relocate 'org.apache.datasketches', 'org.apache.iceberg.shaded.org.apache.datasketches' + + archiveClassifier.set(null) + } + + task integrationTest(type: Test) { + useJUnitPlatform() + description = "Test Spark3 Runtime Jar against Spark ${sparkMajorVersion}" + group = "verification" + jvmArgs += project.property('extraJvmArgs') + testClassesDirs = sourceSets.integration.output.classesDirs + classpath = sourceSets.integration.runtimeClasspath + files(shadowJar.archiveFile.get().asFile.path) + inputs.file(shadowJar.archiveFile.get().asFile.path) + } + integrationTest.dependsOn shadowJar + check.dependsOn integrationTest + + jar { + enabled = false + } +} + diff --git a/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/DeleteFileIndexBenchmark.java b/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/DeleteFileIndexBenchmark.java new file mode 100644 index 000000000000..9375ca3a4f46 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/DeleteFileIndexBenchmark.java @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg; + +import com.google.errorprone.annotations.FormatMethod; +import com.google.errorprone.annotations.FormatString; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Timeout; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * A benchmark that evaluates the delete file index build and lookup performance. + * + *

To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-extensions-4.0_2.13:jmh + * -PjmhIncludeRegex=DeleteFileIndexBenchmark + * -PjmhOutputPath=benchmark/iceberg-delete-file-index-benchmark.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 10) +@Timeout(time = 20, timeUnit = TimeUnit.MINUTES) +@BenchmarkMode(Mode.SingleShotTime) +public class DeleteFileIndexBenchmark { + + private static final String TABLE_NAME = "test_table"; + private static final String PARTITION_COLUMN = "ss_ticket_number"; + + private static final int NUM_PARTITIONS = 50; + private static final int NUM_DATA_FILES_PER_PARTITION = 50_000; + private static final int NUM_DELETE_FILES_PER_PARTITION = 100; + + private final Configuration hadoopConf = new Configuration(); + private SparkSession spark; + private Table table; + + private List dataFiles; + + @Param({"partition", "file", "dv"}) + private String type; + + @Setup + public void setupBenchmark() throws NoSuchTableException, ParseException { + setupSpark(); + initTable(); + initDataAndDeletes(); + loadDataFiles(); + } + + private void initDataAndDeletes() { + if (type.equals("partition")) { + initDataAndPartitionScopedDeletes(); + } else if (type.equals("file")) { + initDataAndFileScopedDeletes(); + } else { + initDataAndDVs(); + } + } + + @TearDown + public void tearDownBenchmark() { + dropTable(); + tearDownSpark(); + } + + @Benchmark + @Threads(1) + public void buildIndexAndLookup(Blackhole blackhole) { + DeleteFileIndex deletes = buildDeletes(); + for (DataFile dataFile : dataFiles) { + DeleteFile[] deleteFiles = deletes.forDataFile(dataFile.dataSequenceNumber(), dataFile); + blackhole.consume(deleteFiles); + } + } + + private void loadDataFiles() { + table.refresh(); + + Snapshot snapshot = table.currentSnapshot(); + + ManifestGroup manifestGroup = + new ManifestGroup(table.io(), snapshot.dataManifests(table.io()), ImmutableList.of()); + + try (CloseableIterable> entries = manifestGroup.entries()) { + List files = Lists.newArrayList(); + for (ManifestEntry entry : entries) { + files.add(entry.file().copyWithoutStats()); + } + this.dataFiles = files; + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private DeleteFileIndex buildDeletes() { + table.refresh(); + + List deleteManifests = table.currentSnapshot().deleteManifests(table.io()); + + return DeleteFileIndex.builderFor(table.io(), deleteManifests) + .specsById(table.specs()) + .planWith(ThreadPools.getWorkerPool()) + .build(); + } + + private void initDataAndPartitionScopedDeletes() { + for (int partitionOrdinal = 0; partitionOrdinal < NUM_PARTITIONS; partitionOrdinal++) { + StructLike partition = TestHelpers.Row.of(partitionOrdinal); + + RowDelta rowDelta = table.newRowDelta(); + + for (int fileOrdinal = 0; fileOrdinal < NUM_DATA_FILES_PER_PARTITION; fileOrdinal++) { + DataFile dataFile = FileGenerationUtil.generateDataFile(table, partition); + rowDelta.addRows(dataFile); + } + + for (int fileOrdinal = 0; fileOrdinal < NUM_DELETE_FILES_PER_PARTITION; fileOrdinal++) { + DeleteFile deleteFile = FileGenerationUtil.generatePositionDeleteFile(table, partition); + rowDelta.addDeletes(deleteFile); + } + + rowDelta.commit(); + } + } + + private void initDataAndFileScopedDeletes() { + for (int partitionOrdinal = 0; partitionOrdinal < NUM_PARTITIONS; partitionOrdinal++) { + StructLike partition = TestHelpers.Row.of(partitionOrdinal); + + RowDelta rowDelta = table.newRowDelta(); + + for (int 
fileOrdinal = 0; fileOrdinal < NUM_DATA_FILES_PER_PARTITION; fileOrdinal++) { + DataFile dataFile = FileGenerationUtil.generateDataFile(table, partition); + DeleteFile deleteFile = FileGenerationUtil.generatePositionDeleteFile(table, dataFile); + rowDelta.addRows(dataFile); + rowDelta.addDeletes(deleteFile); + } + + rowDelta.commit(); + } + } + + private void initDataAndDVs() { + for (int partitionOrdinal = 0; partitionOrdinal < NUM_PARTITIONS; partitionOrdinal++) { + StructLike partition = TestHelpers.Row.of(partitionOrdinal); + + RowDelta rowDelta = table.newRowDelta(); + + for (int fileOrdinal = 0; fileOrdinal < NUM_DATA_FILES_PER_PARTITION; fileOrdinal++) { + DataFile dataFile = FileGenerationUtil.generateDataFile(table, partition); + DeleteFile dv = FileGenerationUtil.generateDV(table, dataFile); + rowDelta.addRows(dataFile); + rowDelta.addDeletes(dv); + } + + rowDelta.commit(); + } + } + + private void setupSpark() { + this.spark = + SparkSession.builder() + .config("spark.ui.enabled", false) + .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName()) + .config("spark.sql.catalog.spark_catalog", SparkSessionCatalog.class.getName()) + .config("spark.sql.catalog.spark_catalog.type", "hadoop") + .config("spark.sql.catalog.spark_catalog.warehouse", newWarehouseDir()) + .master("local[*]") + .getOrCreate(); + } + + private void tearDownSpark() { + spark.stop(); + } + + private void initTable() throws NoSuchTableException, ParseException { + sql( + "CREATE TABLE %s ( " + + " `ss_sold_date_sk` INT, " + + " `ss_sold_time_sk` INT, " + + " `ss_item_sk` INT, " + + " `ss_customer_sk` STRING, " + + " `ss_cdemo_sk` STRING, " + + " `ss_hdemo_sk` STRING, " + + " `ss_addr_sk` STRING, " + + " `ss_store_sk` STRING, " + + " `ss_promo_sk` STRING, " + + " `ss_ticket_number` INT, " + + " `ss_quantity` STRING, " + + " `ss_wholesale_cost` STRING, " + + " `ss_list_price` STRING, " + + " `ss_sales_price` STRING, " + + " `ss_ext_discount_amt` STRING, " + + " `ss_ext_sales_price` STRING, " + + " `ss_ext_wholesale_cost` STRING, " + + " `ss_ext_list_price` STRING, " + + " `ss_ext_tax` STRING, " + + " `ss_coupon_amt` STRING, " + + " `ss_net_paid` STRING, " + + " `ss_net_paid_inc_tax` STRING, " + + " `ss_net_profit` STRING " + + ")" + + "USING iceberg " + + "PARTITIONED BY (%s) " + + "TBLPROPERTIES (" + + " '%s' '%b'," + + " '%s' '%s'," + + " '%s' '%d')", + TABLE_NAME, + PARTITION_COLUMN, + TableProperties.MANIFEST_MERGE_ENABLED, + false, + TableProperties.DELETE_MODE, + RowLevelOperationMode.MERGE_ON_READ.modeName(), + TableProperties.FORMAT_VERSION, + type.equals("dv") ? 3 : 2); + + this.table = Spark3Util.loadIcebergTable(spark, TABLE_NAME); + } + + private void dropTable() { + sql("DROP TABLE IF EXISTS %s PURGE", TABLE_NAME); + } + + private String newWarehouseDir() { + return hadoopConf.get("hadoop.tmp.dir") + UUID.randomUUID(); + } + + @FormatMethod + private void sql(@FormatString String query, Object... 
args) { + spark.sql(String.format(query, args)); + } +} diff --git a/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/spark/MergeCardinalityCheckBenchmark.java b/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/spark/MergeCardinalityCheckBenchmark.java new file mode 100644 index 000000000000..963daa2c364c --- /dev/null +++ b/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/spark/MergeCardinalityCheckBenchmark.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import com.google.errorprone.annotations.FormatMethod; +import com.google.errorprone.annotations.FormatString; +import java.util.UUID; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; + +/** + * A benchmark that evaluates the performance of the cardinality check in MERGE operations. + * + *

To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-extensions-4.0_2.13:jmh + * -PjmhIncludeRegex=MergeCardinalityCheckBenchmark + * -PjmhOutputPath=benchmark/iceberg-merge-cardinality-check-benchmark.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class MergeCardinalityCheckBenchmark { + + private static final String TABLE_NAME = "test_table"; + private static final int NUM_FILES = 5; + private static final int NUM_ROWS_PER_FILE = 1_000_000; + private static final int NUM_UNMATCHED_RECORDS_PER_MERGE = 100_000; + + private final Configuration hadoopConf = new Configuration(); + private SparkSession spark; + private long originalSnapshotId; + + @Setup + public void setupBenchmark() throws NoSuchTableException, ParseException { + setupSpark(); + initTable(); + appendData(); + + Table table = Spark3Util.loadIcebergTable(spark, TABLE_NAME); + this.originalSnapshotId = table.currentSnapshot().snapshotId(); + } + + @TearDown + public void tearDownBenchmark() { + tearDownSpark(); + dropTable(); + } + + @Benchmark + @Threads(1) + public void copyOnWriteMergeCardinalityCheck10PercentUpdates() { + runBenchmark(RowLevelOperationMode.COPY_ON_WRITE, 0.1); + } + + @Benchmark + @Threads(1) + public void copyOnWriteMergeCardinalityCheck30PercentUpdates() { + runBenchmark(RowLevelOperationMode.COPY_ON_WRITE, 0.3); + } + + @Benchmark + @Threads(1) + public void copyOnWriteMergeCardinalityCheck90PercentUpdates() { + runBenchmark(RowLevelOperationMode.COPY_ON_WRITE, 0.9); + } + + @Benchmark + @Threads(1) + public void mergeOnReadMergeCardinalityCheck10PercentUpdates() { + runBenchmark(RowLevelOperationMode.MERGE_ON_READ, 0.1); + } + + @Benchmark + @Threads(1) + public void mergeOnReadMergeCardinalityCheck30PercentUpdates() { + runBenchmark(RowLevelOperationMode.MERGE_ON_READ, 0.3); + } + + @Benchmark + @Threads(1) + public void mergeOnReadMergeCardinalityCheck90PercentUpdates() { + runBenchmark(RowLevelOperationMode.MERGE_ON_READ, 0.9); + } + + private void runBenchmark(RowLevelOperationMode mode, double updatePercentage) { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')", + TABLE_NAME, TableProperties.MERGE_MODE, mode.modeName()); + + Dataset insertDataDF = spark.range(-NUM_UNMATCHED_RECORDS_PER_MERGE, 0, 1); + Dataset updateDataDF = spark.range((long) (updatePercentage * NUM_ROWS_PER_FILE)); + Dataset sourceDF = updateDataDF.union(insertDataDF); + sourceDF.createOrReplaceTempView("source"); + + sql( + "MERGE INTO %s t USING source s " + + "ON t.id = s.id " + + "WHEN MATCHED THEN " + + " UPDATE SET stringCol = 'invalid' " + + "WHEN NOT MATCHED THEN " + + " INSERT (id, intCol, floatCol, doubleCol, decimalCol, dateCol, timestampCol, stringCol) " + + " VALUES (s.id, null, null, null, null, null, null, 'new')", + TABLE_NAME); + + sql( + "CALL system.rollback_to_snapshot(table => '%s', snapshot_id => %dL)", + TABLE_NAME, originalSnapshotId); + } + + private void setupSpark() { + this.spark = + SparkSession.builder() + .config("spark.ui.enabled", false) + .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName()) + .config("spark.sql.catalog.spark_catalog", SparkSessionCatalog.class.getName()) + .config("spark.sql.catalog.spark_catalog.type", "hadoop") + .config("spark.sql.catalog.spark_catalog.warehouse", newWarehouseDir()) + .config(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false") + 
.config(SQLConf.RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED().key(), "false") + .config(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false") + .config(SQLConf.SHUFFLE_PARTITIONS().key(), "2") + .master("local") + .getOrCreate(); + } + + private void tearDownSpark() { + spark.stop(); + } + + private void initTable() { + sql( + "CREATE TABLE %s ( " + + " id LONG, intCol INT, floatCol FLOAT, doubleCol DOUBLE, " + + " decimalCol DECIMAL(20, 5), dateCol DATE, timestampCol TIMESTAMP, " + + " stringCol STRING)" + + "USING iceberg " + + "TBLPROPERTIES (" + + " '%s' '%s'," + + " '%s' '%d'," + + " '%s' '%d')", + TABLE_NAME, + TableProperties.MERGE_DISTRIBUTION_MODE, + DistributionMode.NONE.modeName(), + TableProperties.SPLIT_OPEN_FILE_COST, + Integer.MAX_VALUE, + TableProperties.FORMAT_VERSION, + 2); + + sql("ALTER TABLE %s WRITE ORDERED BY id", TABLE_NAME); + } + + private void dropTable() { + sql("DROP TABLE IF EXISTS %s PURGE", TABLE_NAME); + } + + private void appendData() throws NoSuchTableException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset inputDF = + spark + .range(NUM_ROWS_PER_FILE) + .withColumn("intCol", expr("CAST(id AS INT)")) + .withColumn("floatCol", expr("CAST(id AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(id AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(id AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(inputDF); + } + } + + private void appendAsFile(Dataset df) throws NoSuchTableException { + // ensure the schema is precise (including nullability) + StructType sparkSchema = spark.table(TABLE_NAME).schema(); + spark.createDataFrame(df.rdd(), sparkSchema).coalesce(1).writeTo(TABLE_NAME).append(); + } + + private String newWarehouseDir() { + return hadoopConf.get("hadoop.tmp.dir") + UUID.randomUUID(); + } + + @FormatMethod + private void sql(@FormatString String query, Object... args) { + spark.sql(String.format(query, args)); + } +} diff --git a/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/spark/PlanningBenchmark.java b/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/spark/PlanningBenchmark.java new file mode 100644 index 000000000000..84693e7986c0 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/spark/PlanningBenchmark.java @@ -0,0 +1,409 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; + +import com.google.errorprone.annotations.FormatMethod; +import com.google.errorprone.annotations.FormatString; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.BatchScan; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileGenerationUtil; +import org.apache.iceberg.PlanningMode; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.SparkDistributedDataScan; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions; +import org.apache.iceberg.types.Conversions; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Timeout; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * A benchmark that evaluates the job planning performance. + * + *

To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-extensions-4.0_2.13:jmh + * -PjmhIncludeRegex=PlanningBenchmark + * -PjmhOutputPath=benchmark/iceberg-planning-benchmark.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@Timeout(time = 20, timeUnit = TimeUnit.MINUTES) +@BenchmarkMode(Mode.SingleShotTime) +public class PlanningBenchmark { + + private static final String TABLE_NAME = "test_table"; + private static final String PARTITION_COLUMN = "ss_ticket_number"; + private static final int PARTITION_VALUE = 10; + private static final String SORT_KEY_COLUMN = "ss_sold_date_sk"; + private static final int SORT_KEY_VALUE = 5; + + private static final Expression SORT_KEY_PREDICATE = + Expressions.equal(SORT_KEY_COLUMN, SORT_KEY_VALUE); + private static final Expression PARTITION_PREDICATE = + Expressions.equal(PARTITION_COLUMN, PARTITION_VALUE); + private static final Expression PARTITION_AND_SORT_KEY_PREDICATE = + Expressions.and(PARTITION_PREDICATE, SORT_KEY_PREDICATE); + + private static final int NUM_PARTITIONS = 30; + private static final int NUM_DATA_FILES_PER_PARTITION = 50_000; + private static final int NUM_DELETE_FILES_PER_PARTITION = 50; + + private final Configuration hadoopConf = new Configuration(); + private SparkSession spark; + private Table table; + + @Param({"partition", "file", "dv"}) + private String type; + + @Setup + public void setupBenchmark() throws NoSuchTableException, ParseException { + setupSpark(); + initTable(); + initDataAndDeletes(); + } + + @TearDown + public void tearDownBenchmark() { + dropTable(); + tearDownSpark(); + } + + @Benchmark + @Threads(1) + public void localPlanningWithPartitionAndMinMaxFilter(Blackhole blackhole) { + BatchScan scan = table.newBatchScan(); + List fileTasks = planFilesWithoutColumnStats(scan, PARTITION_AND_SORT_KEY_PREDICATE); + blackhole.consume(fileTasks); + } + + @Benchmark + @Threads(1) + public void distributedPlanningWithPartitionAndMinMaxFilter(Blackhole blackhole) { + BatchScan scan = newDistributedScan(DISTRIBUTED, DISTRIBUTED); + List fileTasks = planFilesWithoutColumnStats(scan, PARTITION_AND_SORT_KEY_PREDICATE); + blackhole.consume(fileTasks); + } + + @Benchmark + @Threads(1) + public void localPlanningWithMinMaxFilter(Blackhole blackhole) { + BatchScan scan = table.newBatchScan(); + List fileTasks = planFilesWithoutColumnStats(scan, SORT_KEY_PREDICATE); + blackhole.consume(fileTasks); + } + + @Benchmark + @Threads(1) + public void distributedPlanningWithMinMaxFilter(Blackhole blackhole) { + BatchScan scan = newDistributedScan(DISTRIBUTED, DISTRIBUTED); + List fileTasks = planFilesWithoutColumnStats(scan, SORT_KEY_PREDICATE); + blackhole.consume(fileTasks); + } + + @Benchmark + @Threads(1) + public void localPlanningWithoutFilter(Blackhole blackhole) { + BatchScan scan = table.newBatchScan(); + List fileTasks = planFilesWithoutColumnStats(scan, Expressions.alwaysTrue()); + blackhole.consume(fileTasks); + } + + @Benchmark + @Threads(1) + public void distributedPlanningWithoutFilter(Blackhole blackhole) { + BatchScan scan = newDistributedScan(DISTRIBUTED, DISTRIBUTED); + List fileTasks = planFilesWithoutColumnStats(scan, Expressions.alwaysTrue()); + blackhole.consume(fileTasks); + } + + @Benchmark + @Threads(1) + public void localPlanningWithoutFilterWithStats(Blackhole blackhole) { + BatchScan scan = table.newBatchScan(); + List fileTasks = planFilesWithColumnStats(scan, Expressions.alwaysTrue()); + 
blackhole.consume(fileTasks); + } + + @Benchmark + @Threads(1) + public void distributedPlanningWithoutFilterWithStats(Blackhole blackhole) { + BatchScan scan = newDistributedScan(DISTRIBUTED, DISTRIBUTED); + List fileTasks = planFilesWithColumnStats(scan, Expressions.alwaysTrue()); + blackhole.consume(fileTasks); + } + + @Benchmark + @Threads(1) + public void distributedDataLocalDeletesPlanningWithoutFilterWithStats(Blackhole blackhole) { + BatchScan scan = newDistributedScan(DISTRIBUTED, LOCAL); + List fileTasks = planFilesWithColumnStats(scan, Expressions.alwaysTrue()); + blackhole.consume(fileTasks); + } + + @Benchmark + @Threads(1) + public void localDataDistributedDeletesPlanningWithoutFilterWithStats(Blackhole blackhole) { + BatchScan scan = newDistributedScan(LOCAL, DISTRIBUTED); + List fileTasks = planFilesWithColumnStats(scan, Expressions.alwaysTrue()); + blackhole.consume(fileTasks); + } + + @Benchmark + @Threads(1) + public void localPlanningViaDistributedScanWithoutFilterWithStats(Blackhole blackhole) { + BatchScan scan = newDistributedScan(LOCAL, LOCAL); + List fileTasks = planFilesWithColumnStats(scan, Expressions.alwaysTrue()); + blackhole.consume(fileTasks); + } + + private void setupSpark() { + this.spark = + SparkSession.builder() + .config("spark.ui.enabled", false) + .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .config("spark.driver.maxResultSize", "8G") + .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName()) + .config("spark.sql.catalog.spark_catalog", SparkSessionCatalog.class.getName()) + .config("spark.sql.catalog.spark_catalog.type", "hadoop") + .config("spark.sql.catalog.spark_catalog.warehouse", newWarehouseDir()) + .master("local[*]") + .getOrCreate(); + } + + private void tearDownSpark() { + spark.stop(); + } + + private void initTable() throws NoSuchTableException, ParseException { + sql( + "CREATE TABLE %s ( " + + " `ss_sold_date_sk` INT, " + + " `ss_sold_time_sk` INT, " + + " `ss_item_sk` INT, " + + " `ss_customer_sk` STRING, " + + " `ss_cdemo_sk` STRING, " + + " `ss_hdemo_sk` STRING, " + + " `ss_addr_sk` STRING, " + + " `ss_store_sk` STRING, " + + " `ss_promo_sk` STRING, " + + " `ss_ticket_number` INT, " + + " `ss_quantity` STRING, " + + " `ss_wholesale_cost` STRING, " + + " `ss_list_price` STRING, " + + " `ss_sales_price` STRING, " + + " `ss_ext_discount_amt` STRING, " + + " `ss_ext_sales_price` STRING, " + + " `ss_ext_wholesale_cost` STRING, " + + " `ss_ext_list_price` STRING, " + + " `ss_ext_tax` STRING, " + + " `ss_coupon_amt` STRING, " + + " `ss_net_paid` STRING, " + + " `ss_net_paid_inc_tax` STRING, " + + " `ss_net_profit` STRING " + + ")" + + "USING iceberg " + + "PARTITIONED BY (%s) " + + "TBLPROPERTIES (" + + " '%s' '%b'," + + " '%s' '%s'," + + " '%s' '%d')", + TABLE_NAME, + PARTITION_COLUMN, + TableProperties.MANIFEST_MERGE_ENABLED, + false, + TableProperties.DELETE_MODE, + RowLevelOperationMode.MERGE_ON_READ.modeName(), + TableProperties.FORMAT_VERSION, + type.equals("dv") ? 
3 : 2); + + this.table = Spark3Util.loadIcebergTable(spark, TABLE_NAME); + } + + private void dropTable() { + sql("DROP TABLE IF EXISTS %s PURGE", TABLE_NAME); + } + + private void initDataAndDeletes() { + if (type.equals("partition")) { + initDataAndPartitionScopedDeletes(); + } else if (type.equals("file")) { + initDataAndFileScopedDeletes(); + } else { + initDataAndDVs(); + } + } + + private void initDataAndPartitionScopedDeletes() { + for (int partitionOrdinal = 0; partitionOrdinal < NUM_PARTITIONS; partitionOrdinal++) { + StructLike partition = TestHelpers.Row.of(partitionOrdinal); + + RowDelta rowDelta = table.newRowDelta(); + + for (int fileOrdinal = 0; fileOrdinal < NUM_DATA_FILES_PER_PARTITION; fileOrdinal++) { + DataFile dataFile = generateDataFile(partition, Integer.MIN_VALUE, Integer.MIN_VALUE); + rowDelta.addRows(dataFile); + } + + // add one data file that would match the sort key predicate + DataFile sortKeyDataFile = generateDataFile(partition, SORT_KEY_VALUE, SORT_KEY_VALUE); + rowDelta.addRows(sortKeyDataFile); + + for (int fileOrdinal = 0; fileOrdinal < NUM_DELETE_FILES_PER_PARTITION; fileOrdinal++) { + DeleteFile deleteFile = FileGenerationUtil.generatePositionDeleteFile(table, partition); + rowDelta.addDeletes(deleteFile); + } + + rowDelta.commit(); + } + } + + private void initDataAndFileScopedDeletes() { + for (int partitionOrdinal = 0; partitionOrdinal < NUM_PARTITIONS; partitionOrdinal++) { + StructLike partition = TestHelpers.Row.of(partitionOrdinal); + + RowDelta rowDelta = table.newRowDelta(); + + for (int fileOrdinal = 0; fileOrdinal < NUM_DATA_FILES_PER_PARTITION; fileOrdinal++) { + DataFile dataFile = generateDataFile(partition, Integer.MIN_VALUE, Integer.MIN_VALUE); + DeleteFile deleteFile = FileGenerationUtil.generatePositionDeleteFile(table, dataFile); + rowDelta.addRows(dataFile); + rowDelta.addDeletes(deleteFile); + } + + // add one data file that would match the sort key predicate + DataFile sortKeyDataFile = generateDataFile(partition, SORT_KEY_VALUE, SORT_KEY_VALUE); + rowDelta.addRows(sortKeyDataFile); + + rowDelta.commit(); + } + } + + private void initDataAndDVs() { + for (int partitionOrdinal = 0; partitionOrdinal < NUM_PARTITIONS; partitionOrdinal++) { + StructLike partition = TestHelpers.Row.of(partitionOrdinal); + + RowDelta rowDelta = table.newRowDelta(); + + for (int fileOrdinal = 0; fileOrdinal < NUM_DATA_FILES_PER_PARTITION; fileOrdinal++) { + DataFile dataFile = generateDataFile(partition, Integer.MIN_VALUE, Integer.MIN_VALUE); + DeleteFile dv = FileGenerationUtil.generateDV(table, dataFile); + rowDelta.addRows(dataFile); + rowDelta.addDeletes(dv); + } + + // add one data file that would match the sort key predicate + DataFile sortKeyDataFile = generateDataFile(partition, SORT_KEY_VALUE, SORT_KEY_VALUE); + rowDelta.addRows(sortKeyDataFile); + + rowDelta.commit(); + } + } + + private DataFile generateDataFile(StructLike partition, int sortKeyMin, int sortKeyMax) { + int sortKeyFieldId = table.schema().findField(SORT_KEY_COLUMN).fieldId(); + ByteBuffer lower = Conversions.toByteBuffer(Types.IntegerType.get(), sortKeyMin); + Map lowerBounds = ImmutableMap.of(sortKeyFieldId, lower); + ByteBuffer upper = Conversions.toByteBuffer(Types.IntegerType.get(), sortKeyMax); + Map upperBounds = ImmutableMap.of(sortKeyFieldId, upper); + return FileGenerationUtil.generateDataFile(table, partition, lowerBounds, upperBounds); + } + + private String newWarehouseDir() { + return hadoopConf.get("hadoop.tmp.dir") + UUID.randomUUID(); + } + + private List 
planFilesWithoutColumnStats(BatchScan scan, Expression predicate) { + return planFiles(scan, predicate, false); + } + + private List planFilesWithColumnStats(BatchScan scan, Expression predicate) { + return planFiles(scan, predicate, true); + } + + private List planFiles(BatchScan scan, Expression predicate, boolean withColumnStats) { + table.refresh(); + + BatchScan configuredScan = scan.filter(predicate); + + if (withColumnStats) { + configuredScan = scan.includeColumnStats(); + } + + try (CloseableIterable fileTasks = configuredScan.planFiles()) { + return Lists.newArrayList(fileTasks); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private BatchScan newDistributedScan(PlanningMode dataMode, PlanningMode deleteMode) { + table + .updateProperties() + .set(TableProperties.DATA_PLANNING_MODE, dataMode.modeName()) + .set(TableProperties.DELETE_PLANNING_MODE, deleteMode.modeName()) + .commit(); + SparkReadConf readConf = new SparkReadConf(spark, table, ImmutableMap.of()); + return new SparkDistributedDataScan(spark, table, readConf); + } + + @FormatMethod + private void sql(@FormatString String query, Object... args) { + spark.sql(String.format(query, args)); + } +} diff --git a/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/spark/TaskGroupPlanningBenchmark.java b/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/spark/TaskGroupPlanningBenchmark.java new file mode 100644 index 000000000000..7c2def237874 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/spark/TaskGroupPlanningBenchmark.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import com.google.errorprone.annotations.FormatMethod; +import com.google.errorprone.annotations.FormatString; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileGenerationUtil; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Timeout; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * A benchmark that evaluates the task group planning performance. + * + *

To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-extensions-4.0_2.13:jmh + * -PjmhIncludeRegex=TaskGroupPlanningBenchmark + * -PjmhOutputPath=benchmark/iceberg-task-group-planning-benchmark.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@Timeout(time = 30, timeUnit = TimeUnit.MINUTES) +@BenchmarkMode(Mode.SingleShotTime) +public class TaskGroupPlanningBenchmark { + + private static final String TABLE_NAME = "test_table"; + private static final String PARTITION_COLUMN = "ss_ticket_number"; + + private static final int NUM_PARTITIONS = 150; + private static final int NUM_DATA_FILES_PER_PARTITION = 50_000; + private static final int NUM_DELETE_FILES_PER_PARTITION = 25; + + private final Configuration hadoopConf = new Configuration(); + private SparkSession spark; + private Table table; + + private List fileTasks; + + @Setup + public void setupBenchmark() throws NoSuchTableException, ParseException { + setupSpark(); + initTable(); + initDataAndDeletes(); + loadFileTasks(); + } + + @TearDown + public void tearDownBenchmark() { + dropTable(); + tearDownSpark(); + } + + @Benchmark + @Threads(1) + public void planTaskGroups(Blackhole blackhole) { + SparkReadConf readConf = new SparkReadConf(spark, table, ImmutableMap.of()); + List> taskGroups = + TableScanUtil.planTaskGroups( + fileTasks, + readConf.splitSize(), + readConf.splitLookback(), + readConf.splitOpenFileCost()); + + long rowsCount = 0L; + for (ScanTaskGroup taskGroup : taskGroups) { + rowsCount += taskGroup.estimatedRowsCount(); + } + blackhole.consume(rowsCount); + + long filesCount = 0L; + for (ScanTaskGroup taskGroup : taskGroups) { + filesCount += taskGroup.filesCount(); + } + blackhole.consume(filesCount); + + long sizeBytes = 0L; + for (ScanTaskGroup taskGroup : taskGroups) { + sizeBytes += taskGroup.sizeBytes(); + } + blackhole.consume(sizeBytes); + } + + @Benchmark + @Threads(1) + public void planTaskGroupsWithGrouping(Blackhole blackhole) { + SparkReadConf readConf = new SparkReadConf(spark, table, ImmutableMap.of()); + + List> taskGroups = + TableScanUtil.planTaskGroups( + fileTasks, + readConf.splitSize(), + readConf.splitLookback(), + readConf.splitOpenFileCost(), + Partitioning.groupingKeyType(table.schema(), table.specs().values())); + + long rowsCount = 0L; + for (ScanTaskGroup taskGroup : taskGroups) { + rowsCount += taskGroup.estimatedRowsCount(); + } + blackhole.consume(rowsCount); + + long filesCount = 0L; + for (ScanTaskGroup taskGroup : taskGroups) { + filesCount += taskGroup.filesCount(); + } + blackhole.consume(filesCount); + + long sizeBytes = 0L; + for (ScanTaskGroup taskGroup : taskGroups) { + sizeBytes += taskGroup.sizeBytes(); + } + blackhole.consume(sizeBytes); + } + + private void loadFileTasks() { + table.refresh(); + + try (CloseableIterable fileTasksIterable = table.newScan().planFiles()) { + this.fileTasks = Lists.newArrayList(fileTasksIterable); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private void initDataAndDeletes() { + for (int partitionOrdinal = 0; partitionOrdinal < NUM_PARTITIONS; partitionOrdinal++) { + StructLike partition = TestHelpers.Row.of(partitionOrdinal); + + RowDelta rowDelta = table.newRowDelta(); + + for (int fileOrdinal = 0; fileOrdinal < NUM_DATA_FILES_PER_PARTITION; fileOrdinal++) { + DataFile dataFile = FileGenerationUtil.generateDataFile(table, partition); + rowDelta.addRows(dataFile); + } + + for (int fileOrdinal = 0; fileOrdinal < 
NUM_DELETE_FILES_PER_PARTITION; fileOrdinal++) { + DeleteFile deleteFile = FileGenerationUtil.generatePositionDeleteFile(table, partition); + rowDelta.addDeletes(deleteFile); + } + + rowDelta.commit(); + } + } + + private void setupSpark() { + this.spark = + SparkSession.builder() + .config("spark.ui.enabled", false) + .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName()) + .config("spark.sql.catalog.spark_catalog", SparkSessionCatalog.class.getName()) + .config("spark.sql.catalog.spark_catalog.type", "hadoop") + .config("spark.sql.catalog.spark_catalog.warehouse", newWarehouseDir()) + .master("local[*]") + .getOrCreate(); + } + + private void tearDownSpark() { + spark.stop(); + } + + private void initTable() throws NoSuchTableException, ParseException { + sql( + "CREATE TABLE %s ( " + + " `ss_sold_date_sk` INT, " + + " `ss_sold_time_sk` INT, " + + " `ss_item_sk` INT, " + + " `ss_customer_sk` STRING, " + + " `ss_cdemo_sk` STRING, " + + " `ss_hdemo_sk` STRING, " + + " `ss_addr_sk` STRING, " + + " `ss_store_sk` STRING, " + + " `ss_promo_sk` STRING, " + + " `ss_ticket_number` INT, " + + " `ss_quantity` STRING, " + + " `ss_wholesale_cost` STRING, " + + " `ss_list_price` STRING, " + + " `ss_sales_price` STRING, " + + " `ss_ext_discount_amt` STRING, " + + " `ss_ext_sales_price` STRING, " + + " `ss_ext_wholesale_cost` STRING, " + + " `ss_ext_list_price` STRING, " + + " `ss_ext_tax` STRING, " + + " `ss_coupon_amt` STRING, " + + " `ss_net_paid` STRING, " + + " `ss_net_paid_inc_tax` STRING, " + + " `ss_net_profit` STRING " + + ")" + + "USING iceberg " + + "PARTITIONED BY (%s) " + + "TBLPROPERTIES (" + + " '%s' '%b'," + + " '%s' '%s'," + + " '%s' '%d')", + TABLE_NAME, + PARTITION_COLUMN, + TableProperties.MANIFEST_MERGE_ENABLED, + false, + TableProperties.DELETE_MODE, + RowLevelOperationMode.MERGE_ON_READ.modeName(), + TableProperties.FORMAT_VERSION, + 2); + + this.table = Spark3Util.loadIcebergTable(spark, TABLE_NAME); + } + + private void dropTable() { + sql("DROP TABLE IF EXISTS %s PURGE", TABLE_NAME); + } + + private String newWarehouseDir() { + return hadoopConf.get("hadoop.tmp.dir") + UUID.randomUUID(); + } + + @FormatMethod + private void sql(@FormatString String query, Object... args) { + spark.sql(String.format(query, args)); + } +} diff --git a/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/spark/UpdateProjectionBenchmark.java b/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/spark/UpdateProjectionBenchmark.java new file mode 100644 index 000000000000..d917eae5eb0f --- /dev/null +++ b/spark/v4.0/spark-extensions/src/jmh/java/org/apache/iceberg/spark/UpdateProjectionBenchmark.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import com.google.errorprone.annotations.FormatMethod; +import com.google.errorprone.annotations.FormatString; +import java.util.UUID; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; + +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class UpdateProjectionBenchmark { + + private static final String TABLE_NAME = "test_table"; + private static final int NUM_FILES = 5; + private static final int NUM_ROWS_PER_FILE = 1_000_000; + + private final Configuration hadoopConf = new Configuration(); + private SparkSession spark; + private long originalSnapshotId; + + @Setup + public void setupBenchmark() throws NoSuchTableException, ParseException { + setupSpark(); + initTable(); + appendData(); + + Table table = Spark3Util.loadIcebergTable(spark, TABLE_NAME); + this.originalSnapshotId = table.currentSnapshot().snapshotId(); + } + + @TearDown + public void tearDownBenchmark() { + tearDownSpark(); + dropTable(); + } + + @Benchmark + @Threads(1) + public void copyOnWriteUpdate10Percent() { + runBenchmark(RowLevelOperationMode.COPY_ON_WRITE, 0.1); + } + + @Benchmark + @Threads(1) + public void copyOnWriteUpdate30Percent() { + runBenchmark(RowLevelOperationMode.COPY_ON_WRITE, 0.3); + } + + @Benchmark + @Threads(1) + public void copyOnWriteUpdate75Percent() { + runBenchmark(RowLevelOperationMode.COPY_ON_WRITE, 0.75); + } + + @Benchmark + @Threads(1) + public void mergeOnRead10Percent() { + runBenchmark(RowLevelOperationMode.MERGE_ON_READ, 0.1); + } + + @Benchmark + @Threads(1) + public void mergeOnReadUpdate30Percent() { + runBenchmark(RowLevelOperationMode.MERGE_ON_READ, 0.3); + } + + @Benchmark + @Threads(1) + public void mergeOnReadUpdate75Percent() { + runBenchmark(RowLevelOperationMode.MERGE_ON_READ, 0.75); + } + + private void runBenchmark(RowLevelOperationMode mode, double updatePercentage) { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')", + TABLE_NAME, TableProperties.UPDATE_MODE, mode.modeName()); + + int mod = (int) (NUM_ROWS_PER_FILE / (NUM_ROWS_PER_FILE * updatePercentage)); + + sql( + "UPDATE %s " + + "SET intCol = intCol + 10, 
dateCol = date_add(dateCol, 1) " + + "WHERE mod(id, %d) = 0", + TABLE_NAME, mod); + + sql( + "CALL system.rollback_to_snapshot(table => '%s', snapshot_id => %dL)", + TABLE_NAME, originalSnapshotId); + } + + private void setupSpark() { + this.spark = + SparkSession.builder() + .config("spark.ui.enabled", false) + .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName()) + .config("spark.sql.catalog.spark_catalog", SparkSessionCatalog.class.getName()) + .config("spark.sql.catalog.spark_catalog.type", "hadoop") + .config("spark.sql.catalog.spark_catalog.warehouse", newWarehouseDir()) + .config(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false") + .config(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false") + .config(SQLConf.SHUFFLE_PARTITIONS().key(), "2") + .master("local") + .getOrCreate(); + } + + private void tearDownSpark() { + spark.stop(); + } + + private void initTable() { + sql( + "CREATE TABLE %s ( " + + " id LONG, intCol INT, floatCol FLOAT, doubleCol DOUBLE, " + + " decimalCol DECIMAL(20, 5), dateCol DATE, timestampCol TIMESTAMP, " + + " stringCol STRING)" + + "USING iceberg " + + "TBLPROPERTIES (" + + " '%s' '%s'," + + " '%s' '%d'," + + " '%s' '%d')", + TABLE_NAME, + TableProperties.UPDATE_DISTRIBUTION_MODE, + DistributionMode.NONE.modeName(), + TableProperties.SPLIT_OPEN_FILE_COST, + Integer.MAX_VALUE, + TableProperties.FORMAT_VERSION, + 2); + + sql("ALTER TABLE %s WRITE ORDERED BY id", TABLE_NAME); + } + + private void dropTable() { + sql("DROP TABLE IF EXISTS %s PURGE", TABLE_NAME); + } + + private void appendData() throws NoSuchTableException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset inputDF = + spark + .range(NUM_ROWS_PER_FILE) + .withColumn("intCol", expr("CAST(id AS INT)")) + .withColumn("floatCol", expr("CAST(id AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(id AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(id AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(inputDF); + } + } + + private void appendAsFile(Dataset df) throws NoSuchTableException { + // ensure the schema is precise (including nullability) + StructType sparkSchema = spark.table(TABLE_NAME).schema(); + spark.createDataFrame(df.rdd(), sparkSchema).coalesce(1).writeTo(TABLE_NAME).append(); + } + + private String newWarehouseDir() { + return hadoopConf.get("hadoop.tmp.dir") + UUID.randomUUID(); + } + + @FormatMethod + private void sql(@FormatString String query, Object... args) { + spark.sql(String.format(query, args)); + } +} diff --git a/spark/v4.0/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 b/spark/v4.0/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 new file mode 100644 index 000000000000..b962699d9b47 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 @@ -0,0 +1,374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * This file is an adaptation of Presto's and Spark's grammar files. + */ + +grammar IcebergSqlExtensions; + +@lexer::members { + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + /** + * This method will be called when we see '/*' and try to match it as a bracketed comment. + * If the next character is '+', it should be parsed as hint later, and we cannot match + * it as a bracketed comment. + * + * Returns true if the next character is '+'. + */ + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } +} + +singleStatement + : statement EOF + ; + +statement + : CALL multipartIdentifier '(' (callArgument (',' callArgument)*)? ')' #call + | ALTER TABLE multipartIdentifier ADD PARTITION FIELD transform (AS name=identifier)? #addPartitionField + | ALTER TABLE multipartIdentifier DROP PARTITION FIELD transform #dropPartitionField + | ALTER TABLE multipartIdentifier REPLACE PARTITION FIELD transform WITH transform (AS name=identifier)? #replacePartitionField + | ALTER TABLE multipartIdentifier WRITE writeSpec #setWriteDistributionAndOrdering + | ALTER TABLE multipartIdentifier SET IDENTIFIER_KW FIELDS fieldList #setIdentifierFields + | ALTER TABLE multipartIdentifier DROP IDENTIFIER_KW FIELDS fieldList #dropIdentifierFields + | ALTER TABLE multipartIdentifier createReplaceBranchClause #createOrReplaceBranch + | ALTER TABLE multipartIdentifier createReplaceTagClause #createOrReplaceTag + | ALTER TABLE multipartIdentifier DROP BRANCH (IF EXISTS)? identifier #dropBranch + | ALTER TABLE multipartIdentifier DROP TAG (IF EXISTS)? identifier #dropTag + ; + +createReplaceTagClause + : (CREATE OR)? REPLACE TAG identifier tagOptions + | CREATE TAG (IF NOT EXISTS)? identifier tagOptions + ; + +createReplaceBranchClause + : (CREATE OR)? REPLACE BRANCH identifier branchOptions + | CREATE BRANCH (IF NOT EXISTS)? identifier branchOptions + ; + +tagOptions + : (AS OF VERSION snapshotId)? (refRetain)? + ; + +branchOptions + : (AS OF VERSION snapshotId)? (refRetain)? (snapshotRetention)? 
+ ; + +snapshotRetention + : WITH SNAPSHOT RETENTION minSnapshotsToKeep + | WITH SNAPSHOT RETENTION maxSnapshotAge + | WITH SNAPSHOT RETENTION minSnapshotsToKeep maxSnapshotAge + ; + +refRetain + : RETAIN number timeUnit + ; + +maxSnapshotAge + : number timeUnit + ; + +minSnapshotsToKeep + : number SNAPSHOTS + ; + +writeSpec + : (writeDistributionSpec | writeOrderingSpec)* + ; + +writeDistributionSpec + : DISTRIBUTED BY PARTITION + ; + +writeOrderingSpec + : LOCALLY? ORDERED BY order + | UNORDERED + ; + +callArgument + : expression #positionalArgument + | identifier '=>' expression #namedArgument + ; + +singleOrder + : order EOF + ; + +order + : fields+=orderField (',' fields+=orderField)* + | '(' fields+=orderField (',' fields+=orderField)* ')' + ; + +orderField + : transform direction=(ASC | DESC)? (NULLS nullOrder=(FIRST | LAST))? + ; + +transform + : multipartIdentifier #identityTransform + | transformName=identifier + '(' arguments+=transformArgument (',' arguments+=transformArgument)* ')' #applyTransform + ; + +transformArgument + : multipartIdentifier + | constant + ; + +expression + : constant + | stringMap + | stringArray + ; + +constant + : number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + | identifier STRING #typeConstructor + ; + +stringMap + : MAP '(' constant (',' constant)* ')' + ; + +stringArray + : ARRAY '(' constant (',' constant)* ')' + ; + +booleanValue + : TRUE | FALSE + ; + +number + : MINUS? EXPONENT_VALUE #exponentLiteral + | MINUS? DECIMAL_VALUE #decimalLiteral + | MINUS? INTEGER_VALUE #integerLiteral + | MINUS? BIGINT_LITERAL #bigIntLiteral + | MINUS? SMALLINT_LITERAL #smallIntLiteral + | MINUS? TINYINT_LITERAL #tinyIntLiteral + | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? FLOAT_LITERAL #floatLiteral + | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral + ; + +multipartIdentifier + : parts+=identifier ('.' 
parts+=identifier)* + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +fieldList + : fields+=multipartIdentifier (',' fields+=multipartIdentifier)* + ; + +nonReserved + : ADD | ALTER | AS | ASC | BRANCH | BY | CALL | CREATE | DAYS | DESC | DROP | EXISTS | FIELD | FIRST | HOURS | IF | LAST | NOT | NULLS | OF | OR | ORDERED | PARTITION | TABLE | WRITE + | DISTRIBUTED | LOCALLY | MINUTES | MONTHS | UNORDERED | REPLACE | RETAIN | VERSION | WITH | IDENTIFIER_KW | FIELDS | SET | SNAPSHOT | SNAPSHOTS + | TAG | TRUE | FALSE + | MAP + ; + +snapshotId + : number + ; + +numSnapshots + : number + ; + +timeUnit + : DAYS + | HOURS + | MINUTES + ; + +ADD: 'ADD'; +ALTER: 'ALTER'; +AS: 'AS'; +ASC: 'ASC'; +BRANCH: 'BRANCH'; +BY: 'BY'; +CALL: 'CALL'; +DAYS: 'DAYS'; +DESC: 'DESC'; +DISTRIBUTED: 'DISTRIBUTED'; +DROP: 'DROP'; +EXISTS: 'EXISTS'; +FIELD: 'FIELD'; +FIELDS: 'FIELDS'; +FIRST: 'FIRST'; +HOURS: 'HOURS'; +IF : 'IF'; +LAST: 'LAST'; +LOCALLY: 'LOCALLY'; +MINUTES: 'MINUTES'; +MONTHS: 'MONTHS'; +CREATE: 'CREATE'; +NOT: 'NOT'; +NULLS: 'NULLS'; +OF: 'OF'; +OR: 'OR'; +ORDERED: 'ORDERED'; +PARTITION: 'PARTITION'; +REPLACE: 'REPLACE'; +RETAIN: 'RETAIN'; +RETENTION: 'RETENTION'; +IDENTIFIER_KW: 'IDENTIFIER'; +SET: 'SET'; +SNAPSHOT: 'SNAPSHOT'; +SNAPSHOTS: 'SNAPSHOTS'; +TABLE: 'TABLE'; +TAG: 'TAG'; +UNORDERED: 'UNORDERED'; +VERSION: 'VERSION'; +WITH: 'WITH'; +WRITE: 'WRITE'; + +TRUE: 'TRUE'; +FALSE: 'FALSE'; + +MAP: 'MAP'; +ARRAY: 'ARRAY'; + +PLUS: '+'; +MINUS: '-'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '"' ( ~('"'|'\\') | ('\\' .) )* '"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +EXPONENT_VALUE + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + ; + +DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + +FLOAT_LITERAL + : DIGIT+ EXPONENT? 'F' + | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? + ; + +DOUBLE_LITERAL + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? (BRACKETED_COMMENT|.)*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . 
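To make the statement and clause productions above concrete, here is a minimal Java sketch of SQL that this extensions grammar is written to accept, assuming a session with IcebergSparkSessionExtensions enabled; the catalog name my_catalog, the table db.tbl, and its columns id and ts are illustrative assumptions, not part of this change.

    import org.apache.spark.sql.SparkSession;

    public class IcebergSqlExtensionsExamples {
      public static void main(String[] args) {
        // Assumes an existing Iceberg catalog "my_catalog" and a table db.tbl(id BIGINT, ts TIMESTAMP).
        SparkSession spark = SparkSession.builder().getOrCreate();

        // statement: CALL multipartIdentifier '(' callArgument* ')' with named arguments
        spark.sql("CALL my_catalog.system.rollback_to_snapshot(table => 'db.tbl', snapshot_id => 123L)");

        // statement: ALTER TABLE ... ADD PARTITION FIELD transform (AS name)?
        spark.sql("ALTER TABLE my_catalog.db.tbl ADD PARTITION FIELD bucket(16, id) AS id_bucket");

        // statement: ALTER TABLE ... WRITE writeSpec (distribution and local ordering)
        spark.sql("ALTER TABLE my_catalog.db.tbl WRITE DISTRIBUTED BY PARTITION LOCALLY ORDERED BY id");

        // createReplaceBranchClause with refRetain and snapshotRetention options
        spark.sql("ALTER TABLE my_catalog.db.tbl CREATE BRANCH IF NOT EXISTS audit "
            + "RETAIN 7 DAYS WITH SNAPSHOT RETENTION 5 SNAPSHOTS 30 DAYS");

        // createReplaceTagClause pinned to a snapshot via AS OF VERSION
        spark.sql("ALTER TABLE my_catalog.db.tbl CREATE TAG v1 AS OF VERSION 123");
      }
    }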
+ ; diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala new file mode 100644 index 000000000000..3fca29c294c0 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions + +import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.catalyst.analysis.CheckViews +import org.apache.spark.sql.catalyst.analysis.ProcedureArgumentCoercion +import org.apache.spark.sql.catalyst.analysis.ResolveProcedures +import org.apache.spark.sql.catalyst.analysis.ResolveViews +import org.apache.spark.sql.catalyst.optimizer.ReplaceStaticInvoke +import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser +import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy + +class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) { + + override def apply(extensions: SparkSessionExtensions): Unit = { + // parser extensions + extensions.injectParser { case (_, parser) => new IcebergSparkSqlExtensionsParser(parser) } + + // analyzer extensions + extensions.injectResolutionRule { spark => ResolveProcedures(spark) } + extensions.injectResolutionRule { spark => ResolveViews(spark) } + extensions.injectResolutionRule { _ => ProcedureArgumentCoercion } + extensions.injectCheckRule(_ => CheckViews) + + // optimizer extensions + extensions.injectOptimizerRule { _ => ReplaceStaticInvoke } + + // planner extensions + extensions.injectPlannerStrategy { spark => ExtendedDataSourceV2Strategy(spark) } + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckViews.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckViews.scala new file mode 100644 index 000000000000..b559004b9466 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckViews.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.AlterViewAs +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias +import org.apache.spark.sql.catalyst.plans.logical.View +import org.apache.spark.sql.catalyst.plans.logical.views.CreateIcebergView +import org.apache.spark.sql.catalyst.plans.logical.views.ResolvedV2View +import org.apache.spark.sql.connector.catalog.ViewCatalog +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.util.SchemaUtils + +object CheckViews extends (LogicalPlan => Unit) { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override def apply(plan: LogicalPlan): Unit = { + plan foreach { + case CreateIcebergView(resolvedIdent@ResolvedIdentifier(_: ViewCatalog, _), _, query, columnAliases, _, + _, _, _, _, replace, _) => + verifyColumnCount(resolvedIdent, columnAliases, query) + SchemaUtils.checkColumnNameDuplication(query.schema.fieldNames, SQLConf.get.resolver) + if (replace) { + val viewIdent: Seq[String] = resolvedIdent.catalog.name() +: resolvedIdent.identifier.asMultipartIdentifier + checkCyclicViewReference(viewIdent, query, Seq(viewIdent)) + } + + case AlterViewAs(ResolvedV2View(_, _), _, _) => + throw new IcebergAnalysisException( + "ALTER VIEW AS is not supported. 
Use CREATE OR REPLACE VIEW instead") + case _ => // OK + } + } + + private def verifyColumnCount(ident: ResolvedIdentifier, columns: Seq[String], query: LogicalPlan): Unit = { + if (columns.nonEmpty) { + if (columns.length > query.output.length) { + throw new AnalysisException( + errorClass = "CREATE_VIEW_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + messageParameters = Map( + "viewName" -> String.format("%s.%s", ident.catalog.name(), ident.identifier), + "viewColumns" -> columns.mkString(", "), + "dataColumns" -> query.output.map(c => c.name).mkString(", "))) + } else if (columns.length < query.output.length) { + throw new AnalysisException( + errorClass = "CREATE_VIEW_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", + messageParameters = Map( + "viewName" -> String.format("%s.%s", ident.catalog.name(), ident.identifier), + "viewColumns" -> columns.mkString(", "), + "dataColumns" -> query.output.map(c => c.name).mkString(", "))) + } + } + } + + private def checkCyclicViewReference( + viewIdent: Seq[String], + plan: LogicalPlan, + cyclePath: Seq[Seq[String]]): Unit = { + plan match { + case sub@SubqueryAlias(_, Project(_, _)) => + val currentViewIdent: Seq[String] = sub.identifier.qualifier :+ sub.identifier.name + checkIfRecursiveView(viewIdent, currentViewIdent, cyclePath, sub.children) + case v1View: View => + val currentViewIdent: Seq[String] = v1View.desc.identifier.nameParts + checkIfRecursiveView(viewIdent, currentViewIdent, cyclePath, v1View.children) + case _ => + plan.children.foreach(child => checkCyclicViewReference(viewIdent, child, cyclePath)) + } + + plan.expressions.flatMap(_.flatMap { + case e: SubqueryExpression => + checkCyclicViewReference(viewIdent, e.plan, cyclePath) + None + case _ => None + }) + } + + private def checkIfRecursiveView( + viewIdent: Seq[String], + currentViewIdent: Seq[String], + cyclePath: Seq[Seq[String]], + children: Seq[LogicalPlan] + ): Unit = { + val newCyclePath = cyclePath :+ currentViewIdent + if (currentViewIdent == viewIdent) { + throw new IcebergAnalysisException(String.format("Recursive cycle in view detected: %s (cycle: %s)", + viewIdent.asIdentifier, newCyclePath.map(p => p.mkString(".")).mkString(" -> "))) + } else { + children.foreach { c => + checkCyclicViewReference(viewIdent, c, newCyclePath) + } + } + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ProcedureArgumentCoercion.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ProcedureArgumentCoercion.scala new file mode 100644 index 000000000000..01dbd6952618 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ProcedureArgumentCoercion.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.plans.logical.IcebergCall +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +object ProcedureArgumentCoercion extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case c @ IcebergCall(procedure, args) if c.resolved => + val params = procedure.parameters + + val newArgs = args.zipWithIndex.map { case (arg, index) => + val param = params(index) + val paramType = param.dataType + val argType = arg.dataType + + if (paramType != argType && !Cast.canUpCast(argType, paramType)) { + throw new IcebergAnalysisException( + s"Wrong arg type for ${param.name}: cannot cast $argType to $paramType") + } + + if (paramType != argType) { + Cast(arg, paramType) + } else { + arg + } + } + + if (newArgs != args) { + c.copy(args = newArgs) + } else { + c + } + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveProcedures.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveProcedures.scala new file mode 100644 index 000000000000..2d02d4ce76e0 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveProcedures.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import java.util.Locale +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.logical.CallArgument +import org.apache.spark.sql.catalyst.plans.logical.CallStatement +import org.apache.spark.sql.catalyst.plans.logical.IcebergCall +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.NamedArgument +import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.connector.catalog.CatalogPlugin +import org.apache.spark.sql.connector.catalog.LookupCatalog +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter +import scala.collection.Seq + +case class ResolveProcedures(spark: SparkSession) extends Rule[LogicalPlan] with LookupCatalog { + + protected lazy val catalogManager: CatalogManager = spark.sessionState.catalogManager + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case CallStatement(CatalogAndIdentifier(catalog, ident), args) => + val procedure = catalog.asProcedureCatalog.loadProcedure(ident) + + val params = procedure.parameters + val normalizedParams = normalizeParams(params) + validateParams(normalizedParams) + + val normalizedArgs = normalizeArgs(args) + IcebergCall(procedure, args = buildArgExprs(normalizedParams, normalizedArgs).toSeq) + } + + private def validateParams(params: Seq[ProcedureParameter]): Unit = { + // should not be any duplicate param names + val duplicateParamNames = params.groupBy(_.name).collect { + case (name, matchingParams) if matchingParams.length > 1 => name + } + + if (duplicateParamNames.nonEmpty) { + throw new IcebergAnalysisException(s"Duplicate parameter names: ${duplicateParamNames.mkString("[", ",", "]")}") + } + + // optional params should be at the end + params.sliding(2).foreach { + case Seq(previousParam, currentParam) if !previousParam.required && currentParam.required => + throw new IcebergAnalysisException( + s"Optional parameters must be after required ones but $currentParam is after $previousParam") + case _ => + } + } + + private def buildArgExprs( + params: Seq[ProcedureParameter], + args: Seq[CallArgument]): Seq[Expression] = { + + // build a map of declared parameter names to their positions + val nameToPositionMap = params.map(_.name).zipWithIndex.toMap + + // build a map of parameter names to args + val nameToArgMap = buildNameToArgMap(params, args, nameToPositionMap) + + // verify all required parameters are provided + val missingParamNames = params.filter(_.required).collect { + case param if !nameToArgMap.contains(param.name) => param.name + } + + if (missingParamNames.nonEmpty) { + throw new IcebergAnalysisException(s"Missing required parameters: ${missingParamNames.mkString("[", ",", "]")}") + } + + val argExprs = new Array[Expression](params.size) + + nameToArgMap.foreach { case (name, arg) => + val position = nameToPositionMap(name) + argExprs(position) = arg.expr + } + + // assign nulls to optional params that were not set + params.foreach { + case p if !p.required && !nameToArgMap.contains(p.name) => + val position = nameToPositionMap(p.name) + argExprs(position) = Literal.create(null, 
p.dataType) + case _ => + } + + argExprs + } + + private def buildNameToArgMap( + params: Seq[ProcedureParameter], + args: Seq[CallArgument], + nameToPositionMap: Map[String, Int]): Map[String, CallArgument] = { + + val containsNamedArg = args.exists(_.isInstanceOf[NamedArgument]) + val containsPositionalArg = args.exists(_.isInstanceOf[PositionalArgument]) + + if (containsNamedArg && containsPositionalArg) { + throw new IcebergAnalysisException("Named and positional arguments cannot be mixed") + } + + if (containsNamedArg) { + buildNameToArgMapUsingNames(args, nameToPositionMap) + } else { + buildNameToArgMapUsingPositions(args, params) + } + } + + private def buildNameToArgMapUsingNames( + args: Seq[CallArgument], + nameToPositionMap: Map[String, Int]): Map[String, CallArgument] = { + + val namedArgs = args.asInstanceOf[Seq[NamedArgument]] + + val validationErrors = namedArgs.groupBy(_.name).collect { + case (name, matchingArgs) if matchingArgs.size > 1 => s"Duplicate procedure argument: $name" + case (name, _) if !nameToPositionMap.contains(name) => s"Unknown argument: $name" + } + + if (validationErrors.nonEmpty) { + throw new IcebergAnalysisException(s"Could not build name to arg map: ${validationErrors.mkString(", ")}") + } + + namedArgs.map(arg => arg.name -> arg).toMap + } + + private def buildNameToArgMapUsingPositions( + args: Seq[CallArgument], + params: Seq[ProcedureParameter]): Map[String, CallArgument] = { + + if (args.size > params.size) { + throw new IcebergAnalysisException("Too many arguments for procedure") + } + + args.zipWithIndex.map { case (arg, position) => + val param = params(position) + param.name -> arg + }.toMap + } + + private def normalizeParams(params: Seq[ProcedureParameter]): Seq[ProcedureParameter] = { + params.map { + case param if param.required => + val normalizedName = param.name.toLowerCase(Locale.ROOT) + ProcedureParameter.required(normalizedName, param.dataType) + case param => + val normalizedName = param.name.toLowerCase(Locale.ROOT) + ProcedureParameter.optional(normalizedName, param.dataType) + } + } + + private def normalizeArgs(args: Seq[CallArgument]): Seq[CallArgument] = { + args.map { + case a @ NamedArgument(name, _) => a.copy(name = name.toLowerCase(Locale.ROOT)) + case other => other + } + } + + implicit class CatalogHelper(plugin: CatalogPlugin) { + def asProcedureCatalog: ProcedureCatalog = plugin match { + case procedureCatalog: ProcedureCatalog => + procedureCatalog + case _ => + throw new IcebergAnalysisException(s"Cannot use catalog ${plugin.name}: not a ProcedureCatalog") + } + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveViews.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveViews.scala new file mode 100644 index 000000000000..397b70b188d4 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveViews.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.ViewUtil.IcebergViewHelper +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.expressions.UpCast +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias +import org.apache.spark.sql.catalyst.plans.logical.views.CreateIcebergView +import org.apache.spark.sql.catalyst.plans.logical.views.ResolvedV2View +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.connector.catalog.LookupCatalog +import org.apache.spark.sql.connector.catalog.View +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types.MetadataBuilder + +case class ResolveViews(spark: SparkSession) extends Rule[LogicalPlan] with LookupCatalog { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + protected lazy val catalogManager: CatalogManager = spark.sessionState.catalogManager + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case u@UnresolvedRelation(nameParts, _, _) + if catalogManager.v1SessionCatalog.isTempView(nameParts) => + u + + case u@UnresolvedRelation(parts@CatalogAndIdentifier(catalog, ident), _, _) => + ViewUtil.loadView(catalog, ident) + .map(createViewRelation(parts, _)) + .getOrElse(u) + + case u@UnresolvedTableOrView(CatalogAndIdentifier(catalog, ident), _, _) => + ViewUtil.loadView(catalog, ident) + .map(_ => ResolvedV2View(catalog.asViewCatalog, ident)) + .getOrElse(u) + + case c@CreateIcebergView(ResolvedIdentifier(_, _), _, query, columnAliases, columnComments, _, _, _, _, _, _) + if query.resolved && !c.rewritten => + val aliased = aliasColumns(query, columnAliases, columnComments) + c.copy(query = aliased, queryColumnNames = query.schema.fieldNames, rewritten = true) + } + + private def aliasColumns( + plan: LogicalPlan, + columnAliases: Seq[String], + columnComments: Seq[Option[String]]): LogicalPlan = { + if (columnAliases.isEmpty || columnAliases.length != plan.output.length) { + plan + } else { + val projectList = plan.output.zipWithIndex.map { case (attr, pos) => + if (columnComments.apply(pos).isDefined) { + val meta = new MetadataBuilder().putString("comment", columnComments.apply(pos).get).build() + Alias(attr, columnAliases.apply(pos))(explicitMetadata = Some(meta)) + } else { + Alias(attr, columnAliases.apply(pos))() + } + } + Project(projectList, plan) + } + } + + + private def createViewRelation(nameParts: Seq[String], view: View): LogicalPlan = { + val parsed = parseViewText(nameParts.quoted, 
view.query) + + // Apply any necessary rewrites to preserve correct resolution + val viewCatalogAndNamespace: Seq[String] = view.currentCatalog +: view.currentNamespace.toSeq + val rewritten = rewriteIdentifiers(parsed, viewCatalogAndNamespace); + + // Apply the field aliases and column comments + // This logic differs from how Spark handles views in SessionCatalog.fromCatalogTable. + // This is more strict because it doesn't allow resolution by field name. + val aliases = view.schema.fields.zipWithIndex.map { case (expected, pos) => + val attr = GetColumnByOrdinal(pos, expected.dataType) + Alias(UpCast(attr, expected.dataType), expected.name)(explicitMetadata = Some(expected.metadata)) + } + + SubqueryAlias(nameParts, Project(aliases, rewritten)) + } + + private def parseViewText(name: String, viewText: String): LogicalPlan = { + val origin = Origin( + objectType = Some("VIEW"), + objectName = Some(name) + ) + + try { + CurrentOrigin.withOrigin(origin) { + spark.sessionState.sqlParser.parseQuery(viewText) + } + } catch { + case _: ParseException => + throw QueryCompilationErrors.invalidViewNameError(name); + } + } + + private def rewriteIdentifiers( + plan: LogicalPlan, + catalogAndNamespace: Seq[String]): LogicalPlan = { + // Substitute CTEs within the view, then rewrite unresolved functions and relations + qualifyTableIdentifiers( + qualifyFunctionIdentifiers( + CTESubstitution.apply(plan), + catalogAndNamespace), + catalogAndNamespace) + } + + private def qualifyFunctionIdentifiers( + plan: LogicalPlan, + catalogAndNamespace: Seq[String]): LogicalPlan = plan transformExpressions { + case u@UnresolvedFunction(Seq(name), _, _, _, _, _, _) => + if (!isBuiltinFunction(name)) { + u.copy(nameParts = catalogAndNamespace :+ name) + } else { + u + } + case u@UnresolvedFunction(parts, _, _, _, _, _, _) if !isCatalog(parts.head) => + u.copy(nameParts = catalogAndNamespace.head +: parts) + } + + /** + * Qualify table identifiers with default catalog and namespace if necessary. + */ + private def qualifyTableIdentifiers( + child: LogicalPlan, + catalogAndNamespace: Seq[String]): LogicalPlan = + child transform { + case u@UnresolvedRelation(Seq(table), _, _) => + u.copy(multipartIdentifier = catalogAndNamespace :+ table) + case u@UnresolvedRelation(parts, _, _) if !isCatalog(parts.head) => + u.copy(multipartIdentifier = catalogAndNamespace.head +: parts) + case other => + other.transformExpressions { + case subquery: SubqueryExpression => + subquery.withNewPlan(qualifyTableIdentifiers(subquery.plan, catalogAndNamespace)) + } + } + + private def isCatalog(name: String): Boolean = { + catalogManager.isCatalogRegistered(name) + } + + private def isBuiltinFunction(name: String): Boolean = { + catalogManager.v1SessionCatalog.isBuiltinFunction(FunctionIdentifier(name)) + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteViewCommands.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteViewCommands.scala new file mode 100644 index 000000000000..0546da1653d4 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteViewCommands.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis.ViewUtil.IcebergViewHelper +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.CreateView +import org.apache.spark.sql.catalyst.plans.logical.DropView +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.ShowViews +import org.apache.spark.sql.catalyst.plans.logical.View +import org.apache.spark.sql.catalyst.plans.logical.views.CreateIcebergView +import org.apache.spark.sql.catalyst.plans.logical.views.DropIcebergView +import org.apache.spark.sql.catalyst.plans.logical.views.ResolvedV2View +import org.apache.spark.sql.catalyst.plans.logical.views.ShowIcebergViews +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_FUNCTION +import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.connector.catalog.LookupCatalog +import scala.collection.mutable + +/** + * ResolveSessionCatalog exits early for some v2 View commands, + * thus they are pre-substituted here and then handled in ResolveViews + */ +case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] with LookupCatalog { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + protected lazy val catalogManager: CatalogManager = spark.sessionState.catalogManager + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case DropView(ResolvedIdent(resolved), ifExists) => + DropIcebergView(resolved, ifExists) + + case CreateView(ResolvedIdent(resolved), userSpecifiedColumns, comment, properties, + Some(queryText), query, allowExisting, replace, _) => + val q = CTESubstitution.apply(query) + verifyTemporaryObjectsDontExist(resolved, q) + CreateIcebergView(child = resolved, + queryText = queryText, + query = q, + columnAliases = userSpecifiedColumns.map(_._1), + columnComments = userSpecifiedColumns.map(_._2.orElse(Option.empty)), + comment = comment, + properties = properties, + allowExisting = allowExisting, + replace = replace) + + case view @ ShowViews(CurrentNamespace, pattern, output) => + if (ViewUtil.isViewCatalog(catalogManager.currentCatalog)) { + ShowIcebergViews(ResolvedNamespace(catalogManager.currentCatalog, catalogManager.currentNamespace), + pattern, output) + } else { + view + } + + case ShowViews(UnresolvedNamespace(CatalogAndNamespace(catalog, ns), _), pattern, output) + if ViewUtil.isViewCatalog(catalog) => + ShowIcebergViews(ResolvedNamespace(catalog, ns), pattern, output) + + // needs to be done here instead of in ResolveViews, so that a V2 view can be resolved before the Analyzer + // tries to resolve it, which would result in an error, saying that V2 views aren't supported + case 
u@UnresolvedView(ResolvedView(resolved), _, _, _) => + ViewUtil.loadView(resolved.catalog, resolved.identifier) + .map(_ => ResolvedV2View(resolved.catalog.asViewCatalog, resolved.identifier)) + .getOrElse(u) + } + + private def isTempView(nameParts: Seq[String]): Boolean = { + catalogManager.v1SessionCatalog.isTempView(nameParts) + } + + private def isTempFunction(nameParts: Seq[String]): Boolean = { + if (nameParts.size > 1) { + return false + } + catalogManager.v1SessionCatalog.isTemporaryFunction(nameParts.asFunctionIdentifier) + } + + private object ResolvedIdent { + def unapply(unresolved: UnresolvedIdentifier): Option[ResolvedIdentifier] = unresolved match { + case UnresolvedIdentifier(nameParts, true) if isTempView(nameParts) => + None + + case UnresolvedIdentifier(CatalogAndIdentifier(catalog, ident), _) if ViewUtil.isViewCatalog(catalog) => + Some(ResolvedIdentifier(catalog, ident)) + + case _ => + None + } + } + + /** + * Permanent views are not allowed to reference temp objects + */ + private def verifyTemporaryObjectsDontExist( + identifier: ResolvedIdentifier, + child: LogicalPlan): Unit = { + val tempViews = collectTemporaryViews(child) + if (tempViews.nonEmpty) { + throw invalidRefToTempObject(identifier, tempViews.map(v => v.quoted).mkString("[", ", ", "]"), "view") + } + + val tempFunctions = collectTemporaryFunctions(child) + if (tempFunctions.nonEmpty) { + throw invalidRefToTempObject(identifier, tempFunctions.mkString("[", ", ", "]"), "function") + } + } + + private def invalidRefToTempObject(ident: ResolvedIdentifier, tempObjectNames: String, tempObjectType: String) = { + new IcebergAnalysisException(String.format("Cannot create view %s.%s that references temporary %s: %s", + ident.catalog.name(), ident.identifier, tempObjectType, tempObjectNames)) + } + + /** + * Collect all temporary views and return the identifiers separately + */ + private def collectTemporaryViews(child: LogicalPlan): Seq[Seq[String]] = { + def collectTempViews(child: LogicalPlan): Seq[Seq[String]] = { + child.flatMap { + case unresolved: UnresolvedRelation if isTempView(unresolved.multipartIdentifier) => + Seq(unresolved.multipartIdentifier) + case view: View if view.isTempView => Seq(view.desc.identifier.nameParts) + case plan => plan.expressions.flatMap(_.flatMap { + case e: SubqueryExpression => collectTempViews(e.plan) + case _ => Seq.empty + }) + }.distinct + } + + collectTempViews(child) + } + + private object ResolvedView { + def unapply(identifier: Seq[String]): Option[ResolvedV2View] = identifier match { + case nameParts if isTempView(nameParts) => + None + + case CatalogAndIdentifier(catalog, ident) if ViewUtil.isViewCatalog(catalog) => + ViewUtil.loadView(catalog, ident).flatMap(_ => Some(ResolvedV2View(catalog.asViewCatalog, ident))) + + case _ => + None + } + } + + /** + * Collect the names of all temporary functions. 
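As a rough illustration of the view DDL that RewriteViewCommands, ResolveViews, and CheckViews cooperate on, the Java sketch below assumes IcebergSparkSessionExtensions is enabled and that my_catalog is an Iceberg catalog implementing ViewCatalog with an existing table db.tbl; all of those names are assumptions made only for the example.

    import org.apache.spark.sql.SparkSession;

    public class IcebergViewExamples {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().getOrCreate();

        // CreateView is rewritten to CreateIcebergView; column aliases and comments are
        // applied by ResolveViews and validated by CheckViews (arity, duplicate names).
        spark.sql("CREATE VIEW my_catalog.db.v (renamed_id COMMENT 'aliased id') AS "
            + "SELECT id FROM my_catalog.db.tbl");

        // Referencing a temporary object from a permanent view is rejected.
        spark.range(10).createOrReplaceTempView("tmp");
        // spark.sql("CREATE VIEW my_catalog.db.bad AS SELECT * FROM tmp");  // would fail

        // ALTER VIEW ... AS on a v2 view is rejected by CheckViews; replace the view instead.
        spark.sql("CREATE OR REPLACE VIEW my_catalog.db.v AS SELECT id FROM my_catalog.db.tbl");

        // SHOW VIEWS and DROP VIEW are routed to ShowIcebergViews / DropIcebergView.
        spark.sql("SHOW VIEWS IN my_catalog.db").show();
        spark.sql("DROP VIEW my_catalog.db.v");
      }
    }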
+ */ + private def collectTemporaryFunctions(child: LogicalPlan): Seq[String] = { + val tempFunctions = new mutable.HashSet[String]() + child.resolveExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_FUNCTION)) { + case f @ UnresolvedFunction(nameParts, _, _, _, _, _, _) if isTempFunction(nameParts) => + tempFunctions += nameParts.head + f + case e: SubqueryExpression => + tempFunctions ++= collectTemporaryFunctions(e.plan) + e + } + tempFunctions.toSeq + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ViewUtil.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ViewUtil.scala new file mode 100644 index 000000000000..d46f10b7f5a2 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ViewUtil.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.connector.catalog.CatalogPlugin +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.View +import org.apache.spark.sql.connector.catalog.ViewCatalog +import org.apache.spark.sql.errors.QueryCompilationErrors + +object ViewUtil { + def loadView(catalog: CatalogPlugin, ident: Identifier): Option[View] = catalog match { + case viewCatalog: ViewCatalog => + try { + Option(viewCatalog.loadView(ident)) + } catch { + case _: NoSuchViewException => None + } + case _ => None + } + + def isViewCatalog(catalog: CatalogPlugin): Boolean = { + catalog.isInstanceOf[ViewCatalog] + } + + implicit class IcebergViewHelper(plugin: CatalogPlugin) { + def asViewCatalog: ViewCatalog = plugin match { + case viewCatalog: ViewCatalog => + viewCatalog + case _ => + throw QueryCompilationErrors.missingCatalogAbilityError(plugin, "views") + } + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala new file mode 100644 index 000000000000..d5c4cb84a02a --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.iceberg.spark.functions.SparkFunctions +import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression +import org.apache.spark.sql.catalyst.expressions.BinaryComparison +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.In +import org.apache.spark.sql.catalyst.expressions.InSet +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke +import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.ReplaceData +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.BINARY_COMPARISON +import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND +import org.apache.spark.sql.catalyst.trees.TreePattern.FILTER +import org.apache.spark.sql.catalyst.trees.TreePattern.IN +import org.apache.spark.sql.catalyst.trees.TreePattern.INSET +import org.apache.spark.sql.catalyst.trees.TreePattern.JOIN +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType + +/** + * Spark analyzes calls to Iceberg system functions into {@link StaticInvoke}, which cannot be + * pushed down to the data source. This rule replaces {@link StaticInvoke} with + * {@link ApplyFunctionExpression} for Iceberg system functions used in a filter condition.
+ */ +object ReplaceStaticInvoke extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = + plan.transformWithPruning (_.containsAnyPattern(COMMAND, FILTER, JOIN)) { + case replace @ ReplaceData(_, cond, _, _, _, _) => + replaceStaticInvoke(replace, cond, newCond => replace.copy(condition = newCond)) + + case join @ Join(_, _, _, Some(cond), _) => + replaceStaticInvoke(join, cond, newCond => join.copy(condition = Some(newCond))) + + case filter @ Filter(cond, _) => + replaceStaticInvoke(filter, cond, newCond => filter.copy(condition = newCond)) + } + + private def replaceStaticInvoke[T <: LogicalPlan]( + node: T, + condition: Expression, + copy: Expression => T): T = { + val newCondition = replaceStaticInvoke(condition) + if (newCondition fastEquals condition) node else copy(newCondition) + } + + private def replaceStaticInvoke(condition: Expression): Expression = { + condition.transformWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN, INSET)) { + case in @ In(value: StaticInvoke, _) if canReplace(value) => + in.copy(value = replaceStaticInvoke(value)) + + case in @ InSet(value: StaticInvoke, _) if canReplace(value) => + in.copy(child = replaceStaticInvoke(value)) + + case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable => + c.withNewChildren(Seq(replaceStaticInvoke(left), right)) + + case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable => + c.withNewChildren(Seq(left, replaceStaticInvoke(right))) + } + } + + private def replaceStaticInvoke(invoke: StaticInvoke): Expression = { + // Adaptive from `resolveV2Function` in org.apache.spark.sql.catalyst.analysis.ResolveFunctions + val unbound = SparkFunctions.loadFunctionByClass(invoke.staticObject) + if (unbound == null) { + return invoke + } + + val inputType = StructType(invoke.arguments.zipWithIndex.map { + case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable) + }) + + val bound = try { + unbound.bind(inputType) + } catch { + case _: Exception => + return invoke + } + + if (bound.inputTypes().length != invoke.arguments.length) { + return invoke + } + + bound match { + case scalarFunc: ScalarFunction[_] => + ApplyFunctionExpression(scalarFunc, invoke.arguments) + case _ => invoke + } + } + + @inline + private def canReplace(invoke: StaticInvoke): Boolean = { + invoke.functionName == ScalarFunction.MAGIC_METHOD_NAME && !invoke.foldable + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala new file mode 100644 index 000000000000..554a06d9610b --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
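
To show the kind of query the rule above targets, here is a hedged sketch; the `demo` catalog, `db.events` table, and `event_id` column are placeholders, and it assumes the Iceberg runtime and SQL extensions are on the classpath. A filter on an Iceberg system function that is rewritten to `ApplyFunctionExpression` can then be offered to the Iceberg scan as a pushable predicate.

```scala
import org.apache.spark.sql.SparkSession

object SystemFunctionPushdownExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("system-function-pushdown").getOrCreate()

    // Spark initially resolves demo.system.bucket(...) to a StaticInvoke of the Iceberg
    // ScalarFunction's magic method. After ReplaceStaticInvoke rewrites the comparison,
    // the predicate can be pushed to the data source instead of being evaluated only
    // after the rows are read.
    spark.sql(
      """SELECT *
        |FROM demo.db.events
        |WHERE demo.system.bucket(16, event_id) = 3""".stripMargin)
      .explain(true)
  }
}
```
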
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.parser.extensions + +import java.util.Locale +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.misc.ParseCancellationException +import org.antlr.v4.runtime.tree.TerminalNodeImpl +import org.apache.iceberg.common.DynConstructors +import org.apache.iceberg.spark.ExtendedParser +import org.apache.iceberg.spark.ExtendedParser.RawOrderField +import org.apache.iceberg.spark.procedures.SparkProcedures +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.RewriteViewCommands +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.CompoundBody +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.NonReservedContext +import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.QuotedIdentifierContext +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.VariableSubstitution +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.StructType +import scala.jdk.CollectionConverters._ + +class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserInterface with ExtendedParser { + + import IcebergSparkSqlExtensionsParser._ + + private lazy val substitutor = substitutorCtor.newInstance(SQLConf.get) + private lazy val astBuilder = new IcebergSqlExtensionsAstBuilder(delegate) + + /** + * Parse a string to a DataType. + */ + override def parseDataType(sqlText: String): DataType = { + delegate.parseDataType(sqlText) + } + + /** + * Parse a string to a raw DataType without CHAR/VARCHAR replacement. + */ + def parseRawDataType(sqlText: String): DataType = throw new UnsupportedOperationException() + + /** + * Parse a string to an Expression. + */ + override def parseExpression(sqlText: String): Expression = { + delegate.parseExpression(sqlText) + } + + /** + * Parse a string to a TableIdentifier. + */ + override def parseTableIdentifier(sqlText: String): TableIdentifier = { + delegate.parseTableIdentifier(sqlText) + } + + /** + * Parse a string to a FunctionIdentifier. + */ + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = { + delegate.parseFunctionIdentifier(sqlText) + } + + /** + * Parse a string to a multi-part identifier. + */ + override def parseMultipartIdentifier(sqlText: String): Seq[String] = { + delegate.parseMultipartIdentifier(sqlText) + } + + /** + * Creates StructType for a given SQL string, which is a comma separated list of field + * definitions which will preserve the correct Hive metadata. 
+ */ + override def parseTableSchema(sqlText: String): StructType = { + delegate.parseTableSchema(sqlText) + } + + override def parseScript(sqlScriptText: String): CompoundBody = { + delegate.parseScript(sqlScriptText) + } + + override def parseSortOrder(sqlText: String): java.util.List[RawOrderField] = { + val fields = parse(sqlText) { parser => astBuilder.visitSingleOrder(parser.singleOrder()) } + fields.map { field => + val (term, direction, order) = field + new RawOrderField(term, direction, order) + }.asJava + } + + /** + * Parse a string to a LogicalPlan. + */ + override def parsePlan(sqlText: String): LogicalPlan = { + val sqlTextAfterSubstitution = substitutor.substitute(sqlText) + if (isIcebergCommand(sqlTextAfterSubstitution)) { + parse(sqlTextAfterSubstitution) { parser => astBuilder.visit(parser.singleStatement()) }.asInstanceOf[LogicalPlan] + } else { + RewriteViewCommands(SparkSession.active).apply(delegate.parsePlan(sqlText)) + } + } + + private def isIcebergCommand(sqlText: String): Boolean = { + val normalized = sqlText.toLowerCase(Locale.ROOT).trim() + // Strip simple SQL comments that terminate a line, e.g. comments starting with `--` . + .replaceAll("--.*?\\n", " ") + // Strip newlines. + .replaceAll("\\s+", " ") + // Strip comments of the form /* ... */. This must come after stripping newlines so that + // comments that span multiple lines are caught. + .replaceAll("/\\*.*?\\*/", " ") + // Strip backtick then `system`.`ancestors_of` changes to system.ancestors_of + .replaceAll("`", "") + .trim() + + isIcebergProcedure(normalized) || ( + normalized.startsWith("alter table") && ( + normalized.contains("add partition field") || + normalized.contains("drop partition field") || + normalized.contains("replace partition field") || + normalized.contains("write ordered by") || + normalized.contains("write locally ordered by") || + normalized.contains("write distributed by") || + normalized.contains("write unordered") || + normalized.contains("set identifier fields") || + normalized.contains("drop identifier fields") || + isSnapshotRefDdl(normalized))) + } + + // All builtin Iceberg procedures are under the 'system' namespace + private def isIcebergProcedure(normalized: String): Boolean = { + normalized.startsWith("call") && + SparkProcedures.names().asScala.map("system." + _).exists(normalized.contains) + } + + private def isSnapshotRefDdl(normalized: String): Boolean = { + normalized.contains("create branch") || + normalized.contains("replace branch") || + normalized.contains("create tag") || + normalized.contains("replace tag") || + normalized.contains("drop branch") || + normalized.contains("drop tag") + } + + protected def parse[T](command: String)(toResult: IcebergSqlExtensionsParser => T): T = { + val lexer = new IcebergSqlExtensionsLexer(new UpperCaseCharStream(CharStreams.fromString(command))) + lexer.removeErrorListeners() + lexer.addErrorListener(IcebergParseErrorListener) + + val tokenStream = new CommonTokenStream(lexer) + val parser = new IcebergSqlExtensionsParser(tokenStream) + parser.addParseListener(IcebergSqlExtensionsPostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(IcebergParseErrorListener) + + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) + } + catch { + case _: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.seek(0) // rewind input stream + parser.reset() + + // Try Again. 
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } + catch { + case e: IcebergParseException if e.command.isDefined => + throw e + case e: IcebergParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new IcebergParseException(Option(command), e.message, position, position) + } + } + + override def parseQuery(sqlText: String): LogicalPlan = { + parsePlan(sqlText) + } +} + +object IcebergSparkSqlExtensionsParser { + private val substitutorCtor: DynConstructors.Ctor[VariableSubstitution] = + DynConstructors.builder() + .impl(classOf[VariableSubstitution]) + .impl(classOf[VariableSubstitution], classOf[SQLConf]) + .build() +} + +/* Copied from Apache Spark's to avoid dependency on Spark Internals */ +class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { + override def consume(): Unit = wrapped.consume + override def getSourceName(): String = wrapped.getSourceName + override def index(): Int = wrapped.index + override def mark(): Int = wrapped.mark + override def release(marker: Int): Unit = wrapped.release(marker) + override def seek(where: Int): Unit = wrapped.seek(where) + override def size(): Int = wrapped.size + + override def getText(interval: Interval): String = wrapped.getText(interval) + + // scalastyle:off + override def LA(i: Int): Int = { + val la = wrapped.LA(i) + if (la == 0 || la == IntStream.EOF) la + else Character.toUpperCase(la) + } + // scalastyle:on +} + +/** + * The post-processor validates & cleans-up the parse tree during the parse process. + */ +case object IcebergSqlExtensionsPostProcessor extends IcebergSqlExtensionsBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. 
*/ + override def exitNonReserved(ctx: NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + val newToken = new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + IcebergSqlExtensionsParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins) + parent.addChild(new TerminalNodeImpl(f(newToken))) + } +} + +/* Partially copied from Apache Spark's Parser to avoid dependency on Spark Internals */ +case object IcebergParseErrorListener extends BaseErrorListener { + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: scala.Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + val (start, stop) = offendingSymbol match { + case token: CommonToken => + val start = Origin(Some(line), Some(token.getCharPositionInLine)) + val length = token.getStopIndex - token.getStartIndex + 1 + val stop = Origin(Some(line), Some(token.getCharPositionInLine + length)) + (start, stop) + case _ => + val start = Origin(Some(line), Some(charPositionInLine)) + (start, start) + } + throw new IcebergParseException(None, msg, start, stop) + } +} + +/** + * Copied from Apache Spark + * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It + * contains fields and an extended error message that make reporting and diagnosing errors easier. + */ +class IcebergParseException( + val command: Option[String], + message: String, + val start: Origin, + val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) { + + def this(message: String, ctx: ParserRuleContext) = { + this(Option(IcebergParserUtils.command(ctx)), + message, + IcebergParserUtils.position(ctx.getStart), + IcebergParserUtils.position(ctx.getStop)) + } + + override def getMessage: String = { + val builder = new StringBuilder + builder ++= "\n" ++= message + start match { + case Origin(Some(l), Some(p), Some(startIndex), Some(stopIndex), Some(sqlText), + Some(objectType), Some(objectName), _, _) => + builder ++= s"(line $l, pos $p)\n" + command.foreach { cmd => + val (above, below) = cmd.split("\n").splitAt(l) + builder ++= "\n== SQL ==\n" + above.foreach(builder ++= _ += '\n') + builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" + below.foreach(builder ++= _ += '\n') + } + case _ => + command.foreach { cmd => + builder ++= "\n== SQL ==\n" ++= cmd + } + } + builder.toString + } + + def withCommand(cmd: String): IcebergParseException = { + new IcebergParseException(Option(cmd), message, start, stop) + } +} \ No newline at end of file diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala new file mode 100644 index 000000000000..b95fc7755fb9 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala @@ -0,0 +1,387 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
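
A rough usage sketch of the extensions parser defined above; the session setup, catalog, and table names are hypothetical, and it assumes the ANTLR-generated extension grammar classes are available. Statements whose normalized text matches the `ALTER TABLE ...` / `CALL system.*` patterns are handled by the Iceberg grammar, while everything else falls through to Spark's own parser.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser

object ExtensionsParserRoutingExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("extensions-parser-routing").getOrCreate()
    val parser = new IcebergSparkSqlExtensionsParser(spark.sessionState.sqlParser)

    // Matches "alter table ... write ordered by": parsed by the Iceberg grammar.
    parser.parsePlan("ALTER TABLE demo.db.events WRITE ORDERED BY category, event_id")

    // CALL of a built-in procedure under the `system` namespace: parsed by the Iceberg grammar.
    parser.parsePlan("CALL demo.system.expire_snapshots(table => 'db.events')")

    // Plain query: delegated to Spark's parser (then passed through RewriteViewCommands).
    parser.parsePlan("SELECT count(*) FROM demo.db.events")
  }
}
```
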
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.parser.extensions + +import java.util.Locale +import java.util.concurrent.TimeUnit +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.tree.ParseTree +import org.antlr.v4.runtime.tree.TerminalNode +import org.apache.iceberg.DistributionMode +import org.apache.iceberg.NullOrder +import org.apache.iceberg.SortDirection +import org.apache.iceberg.expressions.Term +import org.apache.iceberg.spark.Spark3Util +import org.apache.spark.sql.catalyst.analysis.IcebergAnalysisException +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.parser.extensions.IcebergParserUtils.withOrigin +import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser._ +import org.apache.spark.sql.catalyst.plans.logical.AddPartitionField +import org.apache.spark.sql.catalyst.plans.logical.BranchOptions +import org.apache.spark.sql.catalyst.plans.logical.CallArgument +import org.apache.spark.sql.catalyst.plans.logical.CallStatement +import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceBranch +import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceTag +import org.apache.spark.sql.catalyst.plans.logical.DropBranch +import org.apache.spark.sql.catalyst.plans.logical.DropIdentifierFields +import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField +import org.apache.spark.sql.catalyst.plans.logical.DropTag +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.NamedArgument +import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument +import org.apache.spark.sql.catalyst.plans.logical.ReplacePartitionField +import org.apache.spark.sql.catalyst.plans.logical.SetIdentifierFields +import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering +import org.apache.spark.sql.catalyst.plans.logical.TagOptions +import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.connector.expressions +import org.apache.spark.sql.connector.expressions.ApplyTransform +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.sql.connector.expressions.LiteralValue +import org.apache.spark.sql.connector.expressions.Transform +import scala.jdk.CollectionConverters._ + +class IcebergSqlExtensionsAstBuilder(delegate: ParserInterface) extends IcebergSqlExtensionsBaseVisitor[AnyRef] { + + private def toBuffer[T](list: java.util.List[T]): scala.collection.mutable.Buffer[T] = list.asScala + 
private def toSeq[T](list: java.util.List[T]): Seq[T] = toBuffer(list).toSeq + + /** + * Create a [[CallStatement]] for a stored procedure call. + */ + override def visitCall(ctx: CallContext): CallStatement = withOrigin(ctx) { + val name = toSeq(ctx.multipartIdentifier.parts).map(_.getText) + val args = toSeq(ctx.callArgument).map(typedVisit[CallArgument]) + CallStatement(name, args) + } + + /** + * Create an ADD PARTITION FIELD logical command. + */ + override def visitAddPartitionField(ctx: AddPartitionFieldContext): AddPartitionField = withOrigin(ctx) { + AddPartitionField( + typedVisit[Seq[String]](ctx.multipartIdentifier), + typedVisit[Transform](ctx.transform), + Option(ctx.name).map(_.getText)) + } + + /** + * Create a DROP PARTITION FIELD logical command. + */ + override def visitDropPartitionField(ctx: DropPartitionFieldContext): DropPartitionField = withOrigin(ctx) { + DropPartitionField( + typedVisit[Seq[String]](ctx.multipartIdentifier), + typedVisit[Transform](ctx.transform)) + } + + /** + * Create a CREATE OR REPLACE BRANCH logical command. + */ + override def visitCreateOrReplaceBranch(ctx: CreateOrReplaceBranchContext): CreateOrReplaceBranch = withOrigin(ctx) { + val createOrReplaceBranchClause = ctx.createReplaceBranchClause() + + val branchName = createOrReplaceBranchClause.identifier() + val branchOptionsContext = Option(createOrReplaceBranchClause.branchOptions()) + val snapshotId = branchOptionsContext.flatMap(branchOptions => Option(branchOptions.snapshotId())) + .map(_.getText.toLong) + val snapshotRetention = branchOptionsContext.flatMap(branchOptions => Option(branchOptions.snapshotRetention())) + val minSnapshotsToKeep = snapshotRetention.flatMap(retention => Option(retention.minSnapshotsToKeep())) + .map(minSnapshots => minSnapshots.number().getText.toLong) + val maxSnapshotAgeMs = snapshotRetention + .flatMap(retention => Option(retention.maxSnapshotAge())) + .map(retention => TimeUnit.valueOf(retention.timeUnit().getText.toUpperCase(Locale.ENGLISH)) + .toMillis(retention.number().getText.toLong)) + val branchRetention = branchOptionsContext.flatMap(branchOptions => Option(branchOptions.refRetain())) + val branchRefAgeMs = branchRetention.map(retain => + TimeUnit.valueOf(retain.timeUnit().getText.toUpperCase(Locale.ENGLISH)).toMillis(retain.number().getText.toLong)) + val create = createOrReplaceBranchClause.CREATE() != null + val replace = ctx.createReplaceBranchClause().REPLACE() != null + val ifNotExists = createOrReplaceBranchClause.EXISTS() != null + + val branchOptions = BranchOptions( + snapshotId, + minSnapshotsToKeep, + maxSnapshotAgeMs, + branchRefAgeMs + ) + + CreateOrReplaceBranch( + typedVisit[Seq[String]](ctx.multipartIdentifier), + branchName.getText, + branchOptions, + create, + replace, + ifNotExists) + } + + /** + * Create an CREATE OR REPLACE TAG logical command. 
+ */ + override def visitCreateOrReplaceTag(ctx: CreateOrReplaceTagContext): CreateOrReplaceTag = withOrigin(ctx) { + val createTagClause = ctx.createReplaceTagClause() + + val tagName = createTagClause.identifier().getText + + val tagOptionsContext = Option(createTagClause.tagOptions()) + val snapshotId = tagOptionsContext.flatMap(tagOptions => Option(tagOptions.snapshotId())) + .map(_.getText.toLong) + val tagRetain = tagOptionsContext.flatMap(tagOptions => Option(tagOptions.refRetain())) + val tagRefAgeMs = tagRetain.map(retain => + TimeUnit.valueOf(retain.timeUnit().getText.toUpperCase(Locale.ENGLISH)).toMillis(retain.number().getText.toLong)) + val tagOptions = TagOptions( + snapshotId, + tagRefAgeMs + ) + + val create = createTagClause.CREATE() != null + val replace = createTagClause.REPLACE() != null + val ifNotExists = createTagClause.EXISTS() != null + + CreateOrReplaceTag(typedVisit[Seq[String]](ctx.multipartIdentifier), + tagName, + tagOptions, + create, + replace, + ifNotExists) + } + + /** + * Create an DROP BRANCH logical command. + */ + override def visitDropBranch(ctx: DropBranchContext): DropBranch = withOrigin(ctx) { + DropBranch(typedVisit[Seq[String]](ctx.multipartIdentifier), ctx.identifier().getText, ctx.EXISTS() != null) + } + + /** + * Create an DROP TAG logical command. + */ + override def visitDropTag(ctx: DropTagContext): DropTag = withOrigin(ctx) { + DropTag(typedVisit[Seq[String]](ctx.multipartIdentifier), ctx.identifier().getText, ctx.EXISTS() != null) + } + + /** + * Create an REPLACE PARTITION FIELD logical command. + */ + override def visitReplacePartitionField(ctx: ReplacePartitionFieldContext): ReplacePartitionField = withOrigin(ctx) { + ReplacePartitionField( + typedVisit[Seq[String]](ctx.multipartIdentifier), + typedVisit[Transform](ctx.transform(0)), + typedVisit[Transform](ctx.transform(1)), + Option(ctx.name).map(_.getText)) + } + + /** + * Create an SET IDENTIFIER FIELDS logical command. + */ + override def visitSetIdentifierFields(ctx: SetIdentifierFieldsContext): SetIdentifierFields = withOrigin(ctx) { + SetIdentifierFields( + typedVisit[Seq[String]](ctx.multipartIdentifier), + toSeq(ctx.fieldList.fields).map(_.getText)) + } + + /** + * Create an DROP IDENTIFIER FIELDS logical command. + */ + override def visitDropIdentifierFields(ctx: DropIdentifierFieldsContext): DropIdentifierFields = withOrigin(ctx) { + DropIdentifierFields( + typedVisit[Seq[String]](ctx.multipartIdentifier), + toSeq(ctx.fieldList.fields).map(_.getText)) + } + + /** + * Create a [[SetWriteDistributionAndOrdering]] for changing the write distribution and ordering. 
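
The visitors above translate the ALTER TABLE clauses recognized by the extension grammar into logical commands. A hedged sketch of the corresponding SQL follows; the table name, field names, branch/tag names, and retention values are made up for illustration.

```scala
import org.apache.spark.sql.SparkSession

object IcebergAlterTableDdlExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("iceberg-alter-table-ddl").getOrCreate()

    val statements = Seq(
      // Parsed by visitAddPartitionField / visitReplacePartitionField / visitDropPartitionField.
      "ALTER TABLE demo.db.events ADD PARTITION FIELD bucket(16, event_id) AS event_bucket",
      "ALTER TABLE demo.db.events REPLACE PARTITION FIELD event_bucket WITH days(ts) AS event_day",
      "ALTER TABLE demo.db.events DROP PARTITION FIELD event_day",
      // Parsed by visitSetIdentifierFields.
      "ALTER TABLE demo.db.events SET IDENTIFIER FIELDS event_id",
      // Parsed by visitCreateOrReplaceBranch / visitCreateOrReplaceTag / visitDropTag.
      "ALTER TABLE demo.db.events CREATE BRANCH nightly RETAIN 7 DAYS WITH SNAPSHOT RETENTION 2 SNAPSHOTS",
      "ALTER TABLE demo.db.events CREATE TAG `v1` AS OF VERSION 1234 RETAIN 180 DAYS",
      "ALTER TABLE demo.db.events DROP TAG `v1`")

    statements.foreach(stmt => spark.sql(stmt))
  }
}
```
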
+ */ + override def visitSetWriteDistributionAndOrdering( + ctx: SetWriteDistributionAndOrderingContext): SetWriteDistributionAndOrdering = { + + val tableName = typedVisit[Seq[String]](ctx.multipartIdentifier) + + val (distributionSpec, orderingSpec) = toDistributionAndOrderingSpec(ctx.writeSpec) + + if (distributionSpec == null && orderingSpec == null) { + throw new IcebergAnalysisException( + "ALTER TABLE has no changes: missing both distribution and ordering clauses") + } + + val distributionMode = if (distributionSpec != null) { + Some(DistributionMode.HASH) + } else if (orderingSpec.UNORDERED != null) { + Some(DistributionMode.NONE) + } else if (orderingSpec.LOCALLY() != null) { + None + } else { + Some(DistributionMode.RANGE) + } + + val ordering = if (orderingSpec != null && orderingSpec.order != null) { + toSeq(orderingSpec.order.fields).map(typedVisit[(Term, SortDirection, NullOrder)]) + } else { + Seq.empty + } + + SetWriteDistributionAndOrdering(tableName, distributionMode, ordering) + } + + private def toDistributionAndOrderingSpec( + writeSpec: WriteSpecContext): (WriteDistributionSpecContext, WriteOrderingSpecContext) = { + + if (writeSpec.writeDistributionSpec.size > 1) { + throw new IcebergAnalysisException("ALTER TABLE contains multiple distribution clauses") + } + + if (writeSpec.writeOrderingSpec.size > 1) { + throw new IcebergAnalysisException("ALTER TABLE contains multiple ordering clauses") + } + + val distributionSpec = toBuffer(writeSpec.writeDistributionSpec).headOption.orNull + val orderingSpec = toBuffer(writeSpec.writeOrderingSpec).headOption.orNull + + (distributionSpec, orderingSpec) + } + + /** + * Create an order field. + */ + override def visitOrderField(ctx: OrderFieldContext): (Term, SortDirection, NullOrder) = { + val term = Spark3Util.toIcebergTerm(typedVisit[Transform](ctx.transform)) + val direction = Option(ctx.ASC).map(_ => SortDirection.ASC) + .orElse(Option(ctx.DESC).map(_ => SortDirection.DESC)) + .getOrElse(SortDirection.ASC) + val nullOrder = Option(ctx.FIRST).map(_ => NullOrder.NULLS_FIRST) + .orElse(Option(ctx.LAST).map(_ => NullOrder.NULLS_LAST)) + .getOrElse(if (direction == SortDirection.ASC) NullOrder.NULLS_FIRST else NullOrder.NULLS_LAST) + (term, direction, nullOrder) + } + + /** + * Create an IdentityTransform for a column reference. + */ + override def visitIdentityTransform(ctx: IdentityTransformContext): Transform = withOrigin(ctx) { + IdentityTransform(FieldReference(typedVisit[Seq[String]](ctx.multipartIdentifier()))) + } + + /** + * Create a named Transform from argument expressions. + */ + override def visitApplyTransform(ctx: ApplyTransformContext): Transform = withOrigin(ctx) { + val args = toSeq(ctx.arguments).map(typedVisit[expressions.Expression]) + ApplyTransform(ctx.transformName.getText, args) + } + + /** + * Create a transform argument from a column reference or a constant. + */ + override def visitTransformArgument(ctx: TransformArgumentContext): expressions.Expression = withOrigin(ctx) { + val reference = Option(ctx.multipartIdentifier()) + .map(typedVisit[Seq[String]]) + .map(FieldReference(_)) + val literal = Option(ctx.constant) + .map(visitConstant) + .map(lit => LiteralValue(lit.value, lit.dataType)) + reference.orElse(literal) + .getOrElse(throw new IcebergParseException(s"Invalid transform argument", ctx)) + } + + /** + * Return a multi-part identifier as Seq[String]. 
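
A short sketch of the WRITE clauses handled by `visitSetWriteDistributionAndOrdering` above; the table name is hypothetical. The comments restate how the visitor maps each clause to a distribution mode and sort order.

```scala
import org.apache.spark.sql.SparkSession

object WriteDistributionAndOrderingExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("write-distribution-ordering").getOrCreate()

    // WRITE ORDERED BY: range distribution plus the given global sort order.
    spark.sql("ALTER TABLE demo.db.events WRITE ORDERED BY category ASC NULLS FIRST, event_id DESC")

    // WRITE LOCALLY ORDERED BY: sort order only; no distribution change is requested (None above).
    spark.sql("ALTER TABLE demo.db.events WRITE LOCALLY ORDERED BY category, event_id")

    // WRITE DISTRIBUTED BY PARTITION: hash distribution, no explicit ordering.
    spark.sql("ALTER TABLE demo.db.events WRITE DISTRIBUTED BY PARTITION")

    // WRITE UNORDERED: distribution mode none, empty sort order.
    spark.sql("ALTER TABLE demo.db.events WRITE UNORDERED")
  }
}
```
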
+ */ + override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = withOrigin(ctx) { + toSeq(ctx.parts).map(_.getText) + } + + override def visitSingleOrder(ctx: SingleOrderContext): Seq[(Term, SortDirection, NullOrder)] = withOrigin(ctx) { + toSeq(ctx.order.fields).map(typedVisit[(Term, SortDirection, NullOrder)]) + } + + /** + * Create a positional argument in a stored procedure call. + */ + override def visitPositionalArgument(ctx: PositionalArgumentContext): CallArgument = withOrigin(ctx) { + val expr = typedVisit[Expression](ctx.expression) + PositionalArgument(expr) + } + + /** + * Create a named argument in a stored procedure call. + */ + override def visitNamedArgument(ctx: NamedArgumentContext): CallArgument = withOrigin(ctx) { + val name = ctx.identifier.getText + val expr = typedVisit[Expression](ctx.expression) + NamedArgument(name, expr) + } + + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + def visitConstant(ctx: ConstantContext): Literal = { + delegate.parseExpression(ctx.getText).asInstanceOf[Literal] + } + + override def visitExpression(ctx: ExpressionContext): Expression = { + // reconstruct the SQL string and parse it using the main Spark parser + // while we can avoid the logic to build Spark expressions, we still have to parse them + // we cannot call ctx.getText directly since it will not render spaces correctly + // that's why we need to recurse down the tree in reconstructSqlString + val sqlString = reconstructSqlString(ctx) + delegate.parseExpression(sqlString) + } + + private def reconstructSqlString(ctx: ParserRuleContext): String = { + toBuffer(ctx.children).map { + case c: ParserRuleContext => reconstructSqlString(c) + case t: TerminalNode => t.getText + }.mkString(" ") + } + + private def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } +} + +/* Partially copied from Apache Spark's Parser to avoid dependency on Spark Internals */ +object IcebergParserUtils { + + private[sql] def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = { + val current = CurrentOrigin.get + CurrentOrigin.set(position(ctx.getStart)) + try { + f + } finally { + CurrentOrigin.set(current) + } + } + + private[sql] def position(token: Token): Origin = { + val opt = Option(token) + Origin(opt.map(_.getLine), opt.map(_.getCharPositionInLine)) + } + + /** Get the command which created the token. */ + private[sql] def command(ctx: ParserRuleContext): String = { + val stream = ctx.getStart.getInputStream + stream.getText(Interval.of(0, stream.size() - 1)) + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddPartitionField.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddPartitionField.scala new file mode 100644 index 000000000000..e8b1b2941161 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddPartitionField.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.expressions.Transform + +case class AddPartitionField(table: Seq[String], transform: Transform, name: Option[String]) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"AddPartitionField ${table.quoted} ${name.map(n => s"$n=").getOrElse("")}${transform.describe}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BranchOptions.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BranchOptions.scala new file mode 100644 index 000000000000..4d7e0a086bda --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BranchOptions.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +case class BranchOptions (snapshotId: Option[Long], numSnapshots: Option[Long], + snapshotRetain: Option[Long], snapshotRefRetain: Option[Long]) diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala new file mode 100644 index 000000000000..9e3fdb0e9e0e --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.connector.iceberg.catalog.Procedure + +case class Call(procedure: Procedure, args: Seq[Expression]) extends LeafCommand { + override lazy val output: Seq[Attribute] = DataTypeUtils.toAttributes(procedure.outputType) + + override def simpleString(maxFields: Int): String = { + s"Call${truncatedString(output.toSeq, "[", ", ", "]", maxFields)} ${procedure.description}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceBranch.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceBranch.scala new file mode 100644 index 000000000000..b7981a3c7a0d --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceBranch.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class CreateOrReplaceBranch( + table: Seq[String], + branch: String, + branchOptions: BranchOptions, + create: Boolean, + replace: Boolean, + ifNotExists: Boolean) extends LeafCommand { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"CreateOrReplaceBranch branch: ${branch} for table: ${table.quoted}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceTag.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceTag.scala new file mode 100644 index 000000000000..6e7db84a90fb --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceTag.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class CreateOrReplaceTag( + table: Seq[String], + tag: String, + tagOptions: TagOptions, + create: Boolean, + replace: Boolean, + ifNotExists: Boolean) extends LeafCommand { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"CreateOrReplaceTag tag: ${tag} for table: ${table.quoted}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropBranch.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropBranch.scala new file mode 100644 index 000000000000..bee0b0fae688 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropBranch.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class DropBranch(table: Seq[String], branch: String, ifExists: Boolean) extends LeafCommand { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"DropBranch branch: ${branch} for table: ${table.quoted}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropIdentifierFields.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropIdentifierFields.scala new file mode 100644 index 000000000000..29dd686a0fba --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropIdentifierFields.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class DropIdentifierFields( + table: Seq[String], + fields: Seq[String]) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"DropIdentifierFields ${table.quoted} (${fields.quoted})" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropPartitionField.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropPartitionField.scala new file mode 100644 index 000000000000..fb1451324182 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropPartitionField.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.expressions.Transform + +case class DropPartitionField(table: Seq[String], transform: Transform) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"DropPartitionField ${table.quoted} ${transform.describe}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropTag.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropTag.scala new file mode 100644 index 000000000000..7e4b38e74d2f --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropTag.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class DropTag(table: Seq[String], tag: String, ifExists: Boolean) extends LeafCommand { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"DropTag tag: ${tag} for table: ${table.quoted}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/IcebergCall.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/IcebergCall.scala new file mode 100644 index 000000000000..032eac13970a --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/IcebergCall.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.connector.iceberg.catalog.Procedure + +case class IcebergCall(procedure: Procedure, args: Seq[Expression]) extends LeafCommand { + override lazy val output: Seq[Attribute] = DataTypeUtils.toAttributes(procedure.outputType) + + override def simpleString(maxFields: Int): String = { + s"IcebergCall${truncatedString(output.toSeq, "[", ", ", "]", maxFields)} ${procedure.description}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplacePartitionField.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplacePartitionField.scala new file mode 100644 index 000000000000..8c660c6f37b1 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplacePartitionField.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.expressions.Transform + +case class ReplacePartitionField( + table: Seq[String], + transformFrom: Transform, + transformTo: Transform, + name: Option[String]) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"ReplacePartitionField ${table.quoted} ${transformFrom.describe} " + + s"with ${name.map(n => s"$n=").getOrElse("")}${transformTo.describe}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetIdentifierFields.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetIdentifierFields.scala new file mode 100644 index 000000000000..a5fa28a617e7 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetIdentifierFields.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.expressions.Transform + +case class SetIdentifierFields( + table: Seq[String], + fields: Seq[String]) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"SetIdentifierFields ${table.quoted} (${fields.quoted})" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TagOptions.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TagOptions.scala new file mode 100644 index 000000000000..85e3b95f4aba --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TagOptions.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +case class TagOptions(snapshotId: Option[Long], snapshotRefRetain: Option[Long]) diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala new file mode 100644 index 000000000000..be15f32bc1b8 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Expression + +/** + * A CALL statement, as parsed from SQL. + */ +case class CallStatement(name: Seq[String], args: Seq[CallArgument]) extends LeafParsedStatement + +/** + * An argument in a CALL statement. + */ +sealed trait CallArgument { + def expr: Expression +} + +/** + * An argument in a CALL statement identified by name. + */ +case class NamedArgument(name: String, expr: Expression) extends CallArgument + +/** + * An argument in a CALL statement identified by position. + */ +case class PositionalArgument(expr: Expression) extends CallArgument diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/views/CreateIcebergView.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/views/CreateIcebergView.scala new file mode 100644 index 000000000000..9366d5efe163 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/views/CreateIcebergView.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.views + +import org.apache.spark.sql.catalyst.plans.logical.BinaryCommand +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +case class CreateIcebergView( + child: LogicalPlan, + queryText: String, + query: LogicalPlan, + columnAliases: Seq[String], + columnComments: Seq[Option[String]], + queryColumnNames: Seq[String] = Seq.empty, + comment: Option[String], + properties: Map[String, String], + allowExisting: Boolean, + replace: Boolean, + rewritten: Boolean = false) extends BinaryCommand { + override def left: LogicalPlan = child + + override def right: LogicalPlan = query + + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): LogicalPlan = + copy(child = newLeft, query = newRight) +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/views/DropIcebergView.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/views/DropIcebergView.scala new file mode 100644 index 000000000000..275dba6fbf5e --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/views/DropIcebergView.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.views + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.UnaryCommand + +case class DropIcebergView( + child: LogicalPlan, + ifExists: Boolean) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): DropIcebergView = + copy(child = newChild) +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/views/ResolvedV2View.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/views/ResolvedV2View.scala new file mode 100644 index 000000000000..b9c05ff0061d --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/views/ResolvedV2View.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.views + +import org.apache.spark.sql.catalyst.analysis.LeafNodeWithoutStats +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.ViewCatalog + +case class ResolvedV2View( + catalog: ViewCatalog, + identifier: Identifier) extends LeafNodeWithoutStats { + override def output: Seq[Attribute] = Nil +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/views/ShowIcebergViews.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/views/ShowIcebergViews.scala new file mode 100644 index 000000000000..b09c27acdc16 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/views/ShowIcebergViews.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.spark.sql.catalyst.plans.logical.views + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.ShowViews +import org.apache.spark.sql.catalyst.plans.logical.UnaryCommand + +case class ShowIcebergViews( + namespace: LogicalPlan, + pattern: Option[String], + override val output: Seq[Attribute] = ShowViews.getOutputAttrs) extends UnaryCommand { + override def child: LogicalPlan = namespace + + override protected def withNewChildInternal(newChild: LogicalPlan): ShowIcebergViews = + copy(namespace = newChild) +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionFieldExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionFieldExec.scala new file mode 100644 index 000000000000..55f327f7e45e --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionFieldExec.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.connector.expressions.Transform + +case class AddPartitionFieldExec( + catalog: TableCatalog, + ident: Identifier, + transform: Transform, + name: Option[String]) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + iceberg.table.updateSpec() + .addField(name.orNull, Spark3Util.toIcebergTerm(transform)) + .commit() + + case table => + throw new UnsupportedOperationException(s"Cannot add partition field to non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"AddPartitionField ${catalog.name}.${ident.quoted} ${name.map(n => s"$n=").getOrElse("")}${transform.describe}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterV2ViewSetPropertiesExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterV2ViewSetPropertiesExec.scala new file mode 100644 index 000000000000..b103d1ee2c58 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterV2ViewSetPropertiesExec.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.ViewCatalog +import org.apache.spark.sql.connector.catalog.ViewChange + + +case class AlterV2ViewSetPropertiesExec( + catalog: ViewCatalog, + ident: Identifier, + properties: Map[String, String]) extends LeafV2CommandExec { + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + val changes = properties.map { + case (property, value) => ViewChange.setProperty(property, value) + }.toSeq + + catalog.alterView(ident, changes: _*) + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"AlterV2ViewSetProperties: ${ident}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterV2ViewUnsetPropertiesExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterV2ViewUnsetPropertiesExec.scala new file mode 100644 index 000000000000..a4103fede24c --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterV2ViewUnsetPropertiesExec.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.IcebergAnalysisException +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.ViewCatalog +import org.apache.spark.sql.connector.catalog.ViewChange + + +case class AlterV2ViewUnsetPropertiesExec( + catalog: ViewCatalog, + ident: Identifier, + propertyKeys: Seq[String], + ifExists: Boolean) extends LeafV2CommandExec { + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + if (!ifExists) { + propertyKeys.filterNot(catalog.loadView(ident).properties.containsKey).foreach { property => + throw new IcebergAnalysisException(s"Cannot remove property that is not set: '$property'") + } + } + + val changes = propertyKeys.map(ViewChange.removeProperty) + catalog.alterView(ident, changes: _*) + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"AlterV2ViewUnsetProperties: ${ident}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala new file mode 100644 index 000000000000..f66962a8c453 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.connector.iceberg.catalog.Procedure +import scala.collection.compat.immutable.ArraySeq + +case class CallExec( + output: Seq[Attribute], + procedure: Procedure, + input: InternalRow) extends LeafV2CommandExec { + + override protected def run(): Seq[InternalRow] = { + ArraySeq.unsafeWrapArray(procedure.call(input)) + } + + override def simpleString(maxFields: Int): String = { + s"CallExec${truncatedString(output, "[", ", ", "]", maxFields)} ${procedure.description}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceBranchExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceBranchExec.scala new file mode 100644 index 000000000000..2be406e7f344 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceBranchExec.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.relocated.com.google.common.base.Preconditions +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.BranchOptions +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog + +case class CreateOrReplaceBranchExec( + catalog: TableCatalog, + ident: Identifier, + branch: String, + branchOptions: BranchOptions, + create: Boolean, + replace: Boolean, + ifNotExists: Boolean) extends LeafV2CommandExec { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val snapshotId: java.lang.Long = branchOptions.snapshotId + .orElse(Option(iceberg.table.currentSnapshot()).map(_.snapshotId())) + .map(java.lang.Long.valueOf) + .orNull + + val manageSnapshots = iceberg.table().manageSnapshots() + val refExists = null != iceberg.table().refs().get(branch) + + def safeCreateBranch(): Unit = { + if (snapshotId == null) { + manageSnapshots.createBranch(branch) + } else { + manageSnapshots.createBranch(branch, snapshotId) + } + } + + if (create && replace && !refExists) { + safeCreateBranch() + } else if (replace) { + Preconditions.checkArgument(snapshotId != null, + "Cannot complete replace branch operation on %s, main has no snapshot", ident) + manageSnapshots.replaceBranch(branch, snapshotId) + } else { + if (refExists && ifNotExists) { + return Nil + } + + safeCreateBranch() + } + + if (branchOptions.numSnapshots.nonEmpty) { + manageSnapshots.setMinSnapshotsToKeep(branch, branchOptions.numSnapshots.get.toInt) + } + + if (branchOptions.snapshotRetain.nonEmpty) { + manageSnapshots.setMaxSnapshotAgeMs(branch, branchOptions.snapshotRetain.get) + } + + if (branchOptions.snapshotRefRetain.nonEmpty) { + manageSnapshots.setMaxRefAgeMs(branch, branchOptions.snapshotRefRetain.get) + } + + manageSnapshots.commit() + + case table => + throw new UnsupportedOperationException(s"Cannot create or replace branch on non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"CreateOrReplace branch: $branch for table: ${ident.quoted}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceTagExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceTagExec.scala new file mode 100644 index 000000000000..372cd7548632 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceTagExec.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.relocated.com.google.common.base.Preconditions +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.TagOptions +import org.apache.spark.sql.connector.catalog._ + +case class CreateOrReplaceTagExec( + catalog: TableCatalog, + ident: Identifier, + tag: String, + tagOptions: TagOptions, + create: Boolean, + replace: Boolean, + ifNotExists: Boolean) extends LeafV2CommandExec { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val snapshotId: java.lang.Long = tagOptions.snapshotId + .orElse(Option(iceberg.table.currentSnapshot()).map(_.snapshotId())) + .map(java.lang.Long.valueOf) + .orNull + + Preconditions.checkArgument(snapshotId != null, + "Cannot complete create or replace tag operation on %s, main has no snapshot", ident) + + val manageSnapshot = iceberg.table.manageSnapshots() + val refExists = null != iceberg.table().refs().get(tag) + + if (create && replace && !refExists) { + manageSnapshot.createTag(tag, snapshotId) + } else if (replace) { + manageSnapshot.replaceTag(tag, snapshotId) + } else { + if (refExists && ifNotExists) { + return Nil + } + + manageSnapshot.createTag(tag, snapshotId) + } + + if (tagOptions.snapshotRefRetain.nonEmpty) { + manageSnapshot.setMaxRefAgeMs(tag, tagOptions.snapshotRefRetain.get) + } + + manageSnapshot.commit() + + case table => + throw new UnsupportedOperationException(s"Cannot create tag to non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"Create tag: $tag for table: ${ident.quoted}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateV2ViewExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateV2ViewExec.scala new file mode 100644 index 000000000000..9015fb338ea5 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateV2ViewExec.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.SupportsReplaceView +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchViewException +import org.apache.spark.sql.catalyst.analysis.ViewAlreadyExistsException +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.ViewCatalog +import org.apache.spark.sql.connector.catalog.ViewInfo +import org.apache.spark.sql.types.StructType +import scala.collection.JavaConverters._ + + +case class CreateV2ViewExec( + catalog: ViewCatalog, + ident: Identifier, + queryText: String, + viewSchema: StructType, + columnAliases: Seq[String], + columnComments: Seq[Option[String]], + queryColumnNames: Seq[String], + comment: Option[String], + properties: Map[String, String], + allowExisting: Boolean, + replace: Boolean) extends LeafV2CommandExec { + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + val currentCatalogName = session.sessionState.catalogManager.currentCatalog.name + val currentCatalog = if (!catalog.name().equals(currentCatalogName)) currentCatalogName else null + val currentNamespace = session.sessionState.catalogManager.currentNamespace + + val engineVersion = "Spark " + org.apache.spark.SPARK_VERSION + val newProperties = properties ++ + comment.map(ViewCatalog.PROP_COMMENT -> _) + + (ViewCatalog.PROP_CREATE_ENGINE_VERSION -> engineVersion, + ViewCatalog.PROP_ENGINE_VERSION -> engineVersion) + + if (replace) { + // CREATE OR REPLACE VIEW + catalog match { + case c: SupportsReplaceView => + try { + replaceView(c, currentCatalog, currentNamespace, newProperties) + } catch { + // view might have been concurrently dropped during replace + case _: NoSuchViewException => + replaceView(c, currentCatalog, currentNamespace, newProperties) + } + case _ => + if (catalog.viewExists(ident)) { + catalog.dropView(ident) + } + + createView(currentCatalog, currentNamespace, newProperties) + } + } else { + try { + // CREATE VIEW [IF NOT EXISTS] + createView(currentCatalog, currentNamespace, newProperties) + } catch { + case _: ViewAlreadyExistsException if allowExisting => // Ignore + } + } + + Nil + } + + private def replaceView( + supportsReplaceView: SupportsReplaceView, + currentCatalog: String, + currentNamespace: Array[String], + newProperties: Map[String, String]) = { + supportsReplaceView.replaceView( + ident, + queryText, + currentCatalog, + currentNamespace, + viewSchema, + queryColumnNames.toArray, + columnAliases.toArray, + columnComments.map(c => c.orNull).toArray, + newProperties.asJava) + } + + private def createView( + currentCatalog: String, + currentNamespace: Array[String], + newProperties: Map[String, String]) = { + val viewInfo: ViewInfo = new ViewInfo( + ident, + queryText, + currentCatalog, + currentNamespace, + viewSchema, + queryColumnNames.toArray, + columnAliases.toArray, + columnComments.map(c => c.orNull).toArray, + newProperties.asJava) + catalog.createView(viewInfo) + } + + override def simpleString(maxFields: Int): String = { + s"CreateV2ViewExec: ${ident}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeV2ViewExec.scala 
b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeV2ViewExec.scala new file mode 100644 index 000000000000..bb08fb18b2bd --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeV2ViewExec.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.util.escapeSingleQuotedString +import org.apache.spark.sql.connector.catalog.View +import org.apache.spark.sql.connector.catalog.ViewCatalog +import org.apache.spark.sql.execution.LeafExecNode +import scala.collection.JavaConverters._ + +case class DescribeV2ViewExec( + output: Seq[Attribute], + view: View, + isExtended: Boolean) extends V2CommandExec with LeafExecNode { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override protected def run(): Seq[InternalRow] = { + if (isExtended) { + (describeSchema :+ emptyRow) ++ describeExtended + } else { + describeSchema + } + } + + private def describeSchema: Seq[InternalRow] = + view.schema().map { column => + toCatalystRow( + column.name, + column.dataType.simpleString, + column.getComment().getOrElse("")) + } + + private def emptyRow: InternalRow = toCatalystRow("", "", "") + + private def describeExtended: Seq[InternalRow] = { + val outputColumns = view.queryColumnNames.mkString("[", ", ", "]") + val properties: Map[String, String] = view.properties.asScala.toMap -- ViewCatalog.RESERVED_PROPERTIES.asScala + val viewCatalogAndNamespace: Seq[String] = view.currentCatalog +: view.currentNamespace.toSeq + val viewProperties = properties.toSeq.sortBy(_._1).map { + case (key, value) => + s"'${escapeSingleQuotedString(key)}' = '${escapeSingleQuotedString(value)}'" + }.mkString("[", ", ", "]") + + + toCatalystRow("# Detailed View Information", "", "") :: + toCatalystRow("Comment", view.properties.getOrDefault(ViewCatalog.PROP_COMMENT, ""), "") :: + toCatalystRow("View Catalog and Namespace", viewCatalogAndNamespace.quoted, "") :: + toCatalystRow("View Query Output Columns", outputColumns, "") :: + toCatalystRow("View Properties", viewProperties, "") :: + toCatalystRow("Created By", view.properties.getOrDefault(ViewCatalog.PROP_CREATE_ENGINE_VERSION, ""), "") :: + Nil + } + + override def simpleString(maxFields: Int): String = { + s"DescribeV2ViewExec" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropBranchExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropBranchExec.scala new file 
mode 100644 index 000000000000..ff8f1820099a --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropBranchExec.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog + +case class DropBranchExec( + catalog: TableCatalog, + ident: Identifier, + branch: String, + ifExists: Boolean) extends LeafV2CommandExec { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val ref = iceberg.table().refs().get(branch) + if (ref != null || !ifExists) { + iceberg.table().manageSnapshots().removeBranch(branch).commit() + } + + case table => + throw new UnsupportedOperationException(s"Cannot drop branch on non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"DropBranch branch: ${branch} for table: ${ident.quoted}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropIdentifierFieldsExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropIdentifierFieldsExec.scala new file mode 100644 index 000000000000..dee778b474f9 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropIdentifierFieldsExec.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.relocated.com.google.common.base.Preconditions +import org.apache.iceberg.relocated.com.google.common.collect.Sets +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog + +case class DropIdentifierFieldsExec( + catalog: TableCatalog, + ident: Identifier, + fields: Seq[String]) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val schema = iceberg.table.schema + val identifierFieldNames = Sets.newHashSet(schema.identifierFieldNames) + + for (name <- fields) { + Preconditions.checkArgument(schema.findField(name) != null, + "Cannot complete drop identifier fields operation: field %s not found", name) + Preconditions.checkArgument(identifierFieldNames.contains(name), + "Cannot complete drop identifier fields operation: %s is not an identifier field", name) + identifierFieldNames.remove(name) + } + + iceberg.table.updateSchema() + .setIdentifierFields(identifierFieldNames) + .commit(); + case table => + throw new UnsupportedOperationException(s"Cannot drop identifier fields in non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"DropIdentifierFields ${catalog.name}.${ident.quoted} (${fields.quoted})"; + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionFieldExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionFieldExec.scala new file mode 100644 index 000000000000..9a153f0c004e --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionFieldExec.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.sql.connector.expressions.Transform + +case class DropPartitionFieldExec( + catalog: TableCatalog, + ident: Identifier, + transform: Transform) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val schema = iceberg.table.schema + transform match { + case IdentityTransform(FieldReference(parts)) if parts.size == 1 && schema.findField(parts.head) == null => + // the name is not present in the Iceberg schema, so it must be a partition field name, not a column name + iceberg.table.updateSpec() + .removeField(parts.head) + .commit() + + case _ => + iceberg.table.updateSpec() + .removeField(Spark3Util.toIcebergTerm(transform)) + .commit() + } + + case table => + throw new UnsupportedOperationException(s"Cannot drop partition field in non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"DropPartitionField ${catalog.name}.${ident.quoted} ${transform.describe}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTagExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTagExec.scala new file mode 100644 index 000000000000..8df88765a986 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTagExec.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog + +case class DropTagExec( + catalog: TableCatalog, + ident: Identifier, + tag: String, + ifExists: Boolean) extends LeafV2CommandExec { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val ref = iceberg.table().refs().get(tag) + if (ref != null || !ifExists) { + iceberg.table().manageSnapshots().removeTag(tag).commit() + } + + case table => + throw new UnsupportedOperationException(s"Cannot drop tag on non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"DropTag tag: ${tag} for table: ${ident.quoted}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropV2ViewExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropV2ViewExec.scala new file mode 100644 index 000000000000..c35af1486fc7 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropV2ViewExec.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchViewException +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.ViewCatalog + + +case class DropV2ViewExec( + catalog: ViewCatalog, + ident: Identifier, + ifExists: Boolean) extends LeafV2CommandExec { + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + val dropped = catalog.dropView(ident) + if (!dropped && !ifExists) { + throw new NoSuchViewException(ident) + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"DropV2View: ${ident}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala new file mode 100644 index 000000000000..6ee3d0f645bc --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.SparkCatalog +import org.apache.iceberg.spark.SparkSessionCatalog +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.IcebergAnalysisException +import org.apache.spark.sql.catalyst.analysis.ResolvedIdentifier +import org.apache.spark.sql.catalyst.analysis.ResolvedNamespace +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.plans.logical.AddPartitionField +import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceBranch +import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceTag +import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation +import org.apache.spark.sql.catalyst.plans.logical.DropBranch +import org.apache.spark.sql.catalyst.plans.logical.DropIdentifierFields +import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField +import org.apache.spark.sql.catalyst.plans.logical.DropTag +import org.apache.spark.sql.catalyst.plans.logical.IcebergCall +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.OrderAwareCoalesce +import org.apache.spark.sql.catalyst.plans.logical.RenameTable +import org.apache.spark.sql.catalyst.plans.logical.ReplacePartitionField +import org.apache.spark.sql.catalyst.plans.logical.SetIdentifierFields +import org.apache.spark.sql.catalyst.plans.logical.SetViewProperties +import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering +import org.apache.spark.sql.catalyst.plans.logical.ShowCreateTable +import org.apache.spark.sql.catalyst.plans.logical.ShowTableProperties +import org.apache.spark.sql.catalyst.plans.logical.UnsetViewProperties +import org.apache.spark.sql.catalyst.plans.logical.views.CreateIcebergView +import org.apache.spark.sql.catalyst.plans.logical.views.DropIcebergView +import org.apache.spark.sql.catalyst.plans.logical.views.ResolvedV2View +import org.apache.spark.sql.catalyst.plans.logical.views.ShowIcebergViews +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.connector.catalog.ViewCatalog +import org.apache.spark.sql.execution.OrderAwareCoalesceExec +import org.apache.spark.sql.execution.SparkPlan +import scala.jdk.CollectionConverters._ + +case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy with PredicateHelper { + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case c @ IcebergCall(procedure, args) => + val input = buildInternalRow(args) + CallExec(c.output, procedure, input) :: Nil + + case AddPartitionField(IcebergCatalogAndIdentifier(catalog, ident), transform, name) => + AddPartitionFieldExec(catalog, ident, transform, name) :: Nil + + case CreateOrReplaceBranch( + IcebergCatalogAndIdentifier(catalog, ident), branch, branchOptions, create, replace, ifNotExists) => + CreateOrReplaceBranchExec(catalog, ident, branch, branchOptions, create, replace, ifNotExists) :: Nil + + case CreateOrReplaceTag( + IcebergCatalogAndIdentifier(catalog, ident), tag, tagOptions, create, replace, ifNotExists) => + CreateOrReplaceTagExec(catalog, ident, 
tag, tagOptions, create, replace, ifNotExists) :: Nil + + case DropBranch(IcebergCatalogAndIdentifier(catalog, ident), branch, ifExists) => + DropBranchExec(catalog, ident, branch, ifExists) :: Nil + + case DropTag(IcebergCatalogAndIdentifier(catalog, ident), tag, ifExists) => + DropTagExec(catalog, ident, tag, ifExists) :: Nil + + case DropPartitionField(IcebergCatalogAndIdentifier(catalog, ident), transform) => + DropPartitionFieldExec(catalog, ident, transform) :: Nil + + case ReplacePartitionField(IcebergCatalogAndIdentifier(catalog, ident), transformFrom, transformTo, name) => + ReplacePartitionFieldExec(catalog, ident, transformFrom, transformTo, name) :: Nil + + case SetIdentifierFields(IcebergCatalogAndIdentifier(catalog, ident), fields) => + SetIdentifierFieldsExec(catalog, ident, fields) :: Nil + + case DropIdentifierFields(IcebergCatalogAndIdentifier(catalog, ident), fields) => + DropIdentifierFieldsExec(catalog, ident, fields) :: Nil + + case SetWriteDistributionAndOrdering( + IcebergCatalogAndIdentifier(catalog, ident), distributionMode, ordering) => + SetWriteDistributionAndOrderingExec(catalog, ident, distributionMode, ordering) :: Nil + + case OrderAwareCoalesce(numPartitions, coalescer, child) => + OrderAwareCoalesceExec(numPartitions, coalescer, planLater(child)) :: Nil + + case RenameTable(ResolvedV2View(oldCatalog: ViewCatalog, oldIdent), newName, isView@true) => + val newIdent = Spark3Util.catalogAndIdentifier(spark, newName.toList.asJava) + if (oldCatalog.name != newIdent.catalog().name()) { + throw new IcebergAnalysisException( + s"Cannot move view between catalogs: from=${oldCatalog.name} and to=${newIdent.catalog().name()}") + } + RenameV2ViewExec(oldCatalog, oldIdent, newIdent.identifier()) :: Nil + + case DropIcebergView(ResolvedIdentifier(viewCatalog: ViewCatalog, ident), ifExists) => + DropV2ViewExec(viewCatalog, ident, ifExists) :: Nil + + case CreateIcebergView(ResolvedIdentifier(viewCatalog: ViewCatalog, ident), queryText, query, + columnAliases, columnComments, queryColumnNames, comment, properties, allowExisting, replace, _) => + CreateV2ViewExec( + catalog = viewCatalog, + ident = ident, + queryText = queryText, + columnAliases = columnAliases, + columnComments = columnComments, + queryColumnNames = queryColumnNames, + viewSchema = query.schema, + comment = comment, + properties = properties, + allowExisting = allowExisting, + replace = replace) :: Nil + + case DescribeRelation(ResolvedV2View(catalog, ident), _, isExtended, output) => + DescribeV2ViewExec(output, catalog.loadView(ident), isExtended) :: Nil + + case ShowTableProperties(ResolvedV2View(catalog, ident), propertyKey, output) => + ShowV2ViewPropertiesExec(output, catalog.loadView(ident), propertyKey) :: Nil + + case ShowIcebergViews(ResolvedNamespace(catalog: ViewCatalog, namespace, _), pattern, output) => + ShowV2ViewsExec(output, catalog, namespace, pattern) :: Nil + + case ShowCreateTable(ResolvedV2View(catalog, ident), _, output) => + ShowCreateV2ViewExec(output, catalog.loadView(ident)) :: Nil + + case SetViewProperties(ResolvedV2View(catalog, ident), properties) => + AlterV2ViewSetPropertiesExec(catalog, ident, properties) :: Nil + + case UnsetViewProperties(ResolvedV2View(catalog, ident), propertyKeys, ifExists) => + AlterV2ViewUnsetPropertiesExec(catalog, ident, propertyKeys, ifExists) :: Nil + + case _ => Nil + } + + private def buildInternalRow(exprs: Seq[Expression]): InternalRow = { + val values = new Array[Any](exprs.size) + for (index <- exprs.indices) { + values(index) = 
exprs(index).eval() + } + new GenericInternalRow(values) + } + + private object IcebergCatalogAndIdentifier { + def unapply(identifier: Seq[String]): Option[(TableCatalog, Identifier)] = { + val catalogAndIdentifier = Spark3Util.catalogAndIdentifier(spark, identifier.asJava) + catalogAndIdentifier.catalog match { + case icebergCatalog: SparkCatalog => + Some((icebergCatalog, catalogAndIdentifier.identifier)) + case icebergCatalog: SparkSessionCatalog[_] => + Some((icebergCatalog, catalogAndIdentifier.identifier)) + case _ => + None + } + } + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameV2ViewExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameV2ViewExec.scala new file mode 100644 index 000000000000..61d362044c3c --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameV2ViewExec.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.ViewCatalog + + +case class RenameV2ViewExec( + catalog: ViewCatalog, + oldIdent: Identifier, + newIdent: Identifier) extends LeafV2CommandExec { + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.renameView(oldIdent, newIdent) + + Seq.empty + } + + + override def simpleString(maxFields: Int): String = { + s"RenameV2View ${oldIdent} to ${newIdent}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplacePartitionFieldExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplacePartitionFieldExec.scala new file mode 100644 index 000000000000..fcae0a5defc4 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplacePartitionFieldExec.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.sql.connector.expressions.Transform + +case class ReplacePartitionFieldExec( + catalog: TableCatalog, + ident: Identifier, + transformFrom: Transform, + transformTo: Transform, + name: Option[String]) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val schema = iceberg.table.schema + transformFrom match { + case IdentityTransform(FieldReference(parts)) if parts.size == 1 && schema.findField(parts.head) == null => + // the name is not present in the Iceberg schema, so it must be a partition field name, not a column name + iceberg.table.updateSpec() + .removeField(parts.head) + .addField(name.orNull, Spark3Util.toIcebergTerm(transformTo)) + .commit() + + case _ => + iceberg.table.updateSpec() + .removeField(Spark3Util.toIcebergTerm(transformFrom)) + .addField(name.orNull, Spark3Util.toIcebergTerm(transformTo)) + .commit() + } + + case table => + throw new UnsupportedOperationException(s"Cannot replace partition field in non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"ReplacePartitionField ${catalog.name}.${ident.quoted} ${transformFrom.describe} " + + s"with ${name.map(n => s"$n=").getOrElse("")}${transformTo.describe}" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetIdentifierFieldsExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetIdentifierFieldsExec.scala new file mode 100644 index 000000000000..b50550ad38ef --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetIdentifierFieldsExec.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import scala.jdk.CollectionConverters._ + +case class SetIdentifierFieldsExec( + catalog: TableCatalog, + ident: Identifier, + fields: Seq[String]) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + iceberg.table.updateSchema() + .setIdentifierFields(fields.asJava) + .commit(); + case table => + throw new UnsupportedOperationException(s"Cannot set identifier fields in non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"SetIdentifierFields ${catalog.name}.${ident.quoted} (${fields.quoted})"; + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetWriteDistributionAndOrderingExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetWriteDistributionAndOrderingExec.scala new file mode 100644 index 000000000000..c9004ddc5bda --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetWriteDistributionAndOrderingExec.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
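Editor's usage sketch (not part of the patch): SetIdentifierFieldsExec backs the `ALTER TABLE ... SET IDENTIFIER FIELDS` extension. Names are illustrative.

```java
import org.apache.spark.sql.SparkSession;

class SetIdentifierFieldsExample {
  static void setIdentifierFields(SparkSession spark) {
    // Records the listed columns as the row identifier in table metadata;
    // identifier columns must be required (NOT NULL) in the Iceberg schema.
    spark.sql("ALTER TABLE my_catalog.db.customers SET IDENTIFIER FIELDS id");
    spark.sql("ALTER TABLE my_catalog.db.orders SET IDENTIFIER FIELDS order_id, region");
  }
}
```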
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.DistributionMode +import org.apache.iceberg.NullOrder +import org.apache.iceberg.SortDirection +import org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE +import org.apache.iceberg.expressions.Term +import org.apache.iceberg.spark.SparkUtil +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog + +case class SetWriteDistributionAndOrderingExec( + catalog: TableCatalog, + ident: Identifier, + distributionMode: Option[DistributionMode], + sortOrder: Seq[(Term, SortDirection, NullOrder)]) extends LeafV2CommandExec { + + import CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val txn = iceberg.table.newTransaction() + + val orderBuilder = txn.replaceSortOrder().caseSensitive(SparkUtil.caseSensitive(session)) + sortOrder.foreach { + case (term, SortDirection.ASC, nullOrder) => + orderBuilder.asc(term, nullOrder) + case (term, SortDirection.DESC, nullOrder) => + orderBuilder.desc(term, nullOrder) + } + orderBuilder.commit() + + distributionMode.foreach { mode => + txn.updateProperties() + .set(WRITE_DISTRIBUTION_MODE, mode.modeName()) + .commit() + } + + txn.commitTransaction() + + case table => + throw new UnsupportedOperationException(s"Cannot set write order of non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + val tableIdent = s"${catalog.name}.${ident.quoted}" + val order = sortOrder.map { + case (term, direction, nullOrder) => s"$term $direction $nullOrder" + }.mkString(", ") + s"SetWriteDistributionAndOrdering $tableIdent $distributionMode $order" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCreateV2ViewExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCreateV2ViewExec.scala new file mode 100644 index 000000000000..3be0f150313b --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCreateV2ViewExec.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
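Editor's usage sketch (not part of the patch): SetWriteDistributionAndOrderingExec serves the `ALTER TABLE ... WRITE ...` extension clauses, replacing the sort order and optionally `write.distribution-mode` in one transaction. Names are illustrative.

```java
import org.apache.spark.sql.SparkSession;

class SetWriteOrderExample {
  static void setWriteOrder(SparkSession spark) {
    // Each ordering term becomes a (term, direction, null-order) tuple in the exec above.
    spark.sql("ALTER TABLE my_catalog.db.events WRITE ORDERED BY category ASC NULLS LAST, id DESC");
    // DISTRIBUTED BY PARTITION requests hash distribution plus a local sort by id.
    spark.sql("ALTER TABLE my_catalog.db.events WRITE DISTRIBUTED BY PARTITION LOCALLY ORDERED BY id");
  }
}
```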
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.util.escapeSingleQuotedString +import org.apache.spark.sql.connector.catalog.View +import org.apache.spark.sql.connector.catalog.ViewCatalog +import org.apache.spark.sql.execution.LeafExecNode +import scala.collection.JavaConverters._ + +case class ShowCreateV2ViewExec(output: Seq[Attribute], view: View) + extends V2CommandExec with LeafExecNode { + + override protected def run(): Seq[InternalRow] = { + val builder = new StringBuilder + builder ++= s"CREATE VIEW ${view.name} " + showColumns(view, builder) + showComment(view, builder) + showProperties(view, builder) + builder ++= s"AS\n${view.query}\n" + + Seq(toCatalystRow(builder.toString)) + } + + private def showColumns(view: View, builder: StringBuilder): Unit = { + val columns = concatByMultiLines( + view.schema().fields + .map(x => s"${x.name}${x.getComment().map(c => s" COMMENT '$c'").getOrElse("")}")) + builder ++= columns + } + + private def showComment(view: View, builder: StringBuilder): Unit = { + Option(view.properties.get(ViewCatalog.PROP_COMMENT)) + .map("COMMENT '" + escapeSingleQuotedString(_) + "'\n") + .foreach(builder.append) + } + + private def showProperties( + view: View, + builder: StringBuilder): Unit = { + val showProps = view.properties.asScala.toMap -- ViewCatalog.RESERVED_PROPERTIES.asScala + if (showProps.nonEmpty) { + val props = conf.redactOptions(showProps).toSeq.sortBy(_._1).map { + case (key, value) => + s"'${escapeSingleQuotedString(key)}' = '${escapeSingleQuotedString(value)}'" + } + + builder ++= "TBLPROPERTIES " + builder ++= concatByMultiLines(props) + } + } + + private def concatByMultiLines(iter: Iterable[String]): String = { + iter.mkString("(\n ", ",\n ", ")\n") + } + + override def simpleString(maxFields: Int): String = { + s"ShowCreateV2ViewExec" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowV2ViewPropertiesExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowV2ViewPropertiesExec.scala new file mode 100644 index 000000000000..89fafe99efc8 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowV2ViewPropertiesExec.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
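Editor's usage sketch (not part of the patch): ShowCreateV2ViewExec rebuilds a CREATE VIEW statement from the view's metadata. The snippet below assumes `SHOW CREATE TABLE` issued against an Iceberg view is what the analyzer routes to this node; names are illustrative.

```java
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

class ShowCreateViewExample {
  static String showCreate(SparkSession spark) {
    // Returns the generated statement:
    // CREATE VIEW <name> (<columns>) [COMMENT ...] [TBLPROPERTIES (...)] AS <query>
    Row row = spark.sql("SHOW CREATE TABLE my_catalog.db.event_agg").first();
    return row.getString(0);
  }
}
```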
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.View +import org.apache.spark.sql.connector.catalog.ViewCatalog +import org.apache.spark.sql.execution.LeafExecNode +import scala.collection.JavaConverters._ + +case class ShowV2ViewPropertiesExec( + output: Seq[Attribute], + view: View, + propertyKey: Option[String]) extends V2CommandExec with LeafExecNode { + + override protected def run(): Seq[InternalRow] = { + propertyKey match { + case Some(p) => + val propValue = properties.getOrElse(p, + s"View ${view.name()} does not have property: $p") + Seq(toCatalystRow(p, propValue)) + case None => + properties.map { + case (k, v) => toCatalystRow(k, v) + }.toSeq + } + } + + + private def properties = { + view.properties.asScala.toMap -- ViewCatalog.RESERVED_PROPERTIES.asScala + } + + override def simpleString(maxFields: Int): String = { + s"ShowV2ViewPropertiesExec" + } +} diff --git a/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowV2ViewsExec.scala b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowV2ViewsExec.scala new file mode 100644 index 000000000000..a0699df13090 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowV2ViewsExec.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
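Editor's usage sketch (not part of the patch): ShowV2ViewPropertiesExec lists the non-reserved view properties, or a single property when a key is given. It assumes `SHOW TBLPROPERTIES` on an Iceberg view reaches this node; names are illustrative.

```java
import org.apache.spark.sql.SparkSession;

class ShowViewPropertiesExample {
  static void showProps(SparkSession spark) {
    // All non-reserved properties as (key, value) rows.
    spark.sql("SHOW TBLPROPERTIES my_catalog.db.event_agg").show(20, false);
    // A single key; per the exec above, a missing key yields a "does not have property" message row.
    spark.sql("SHOW TBLPROPERTIES my_catalog.db.event_agg ('owner')").show(20, false);
  }
}
```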
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.connector.catalog.ViewCatalog +import org.apache.spark.sql.execution.LeafExecNode +import org.apache.spark.sql.internal.SQLConf +import scala.collection.mutable.ArrayBuffer + +case class ShowV2ViewsExec( + output: Seq[Attribute], + catalog: ViewCatalog, + namespace: Seq[String], + pattern: Option[String]) extends V2CommandExec with LeafExecNode { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override protected def run(): Seq[InternalRow] = { + val rows = new ArrayBuffer[InternalRow]() + + // handle GLOBAL VIEWS + val globalTemp: String = SQLConf.get.globalTempDatabase + if (namespace.nonEmpty && globalTemp == namespace.head) { + pattern.map(p => session.sessionState.catalog.globalTempViewManager.listViewNames(p)) + .getOrElse(session.sessionState.catalog.globalTempViewManager.listViewNames("*")) + .map(name => rows += toCatalystRow(globalTemp, name, true)) + } else { + val views = catalog.listViews(namespace: _*) + views.map { view => + if (pattern.map(StringUtils.filterPattern(Seq(view.name()), _).nonEmpty).getOrElse(true)) { + rows += toCatalystRow(view.namespace().quoted, view.name(), false) + } + } + } + + // include TEMP VIEWS + pattern.map(p => session.sessionState.catalog.listLocalTempViews(p)) + .getOrElse(session.sessionState.catalog.listLocalTempViews("*")) + .map(v => rows += toCatalystRow(v.database.toArray.quoted, v.table, true)) + + rows.toSeq + } + + override def simpleString(maxFields: Int): String = { + s"ShowV2ViewsExec" + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/Employee.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/Employee.java new file mode 100644 index 000000000000..8918dfec6584 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/Employee.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
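Editor's usage sketch (not part of the patch): ShowV2ViewsExec merges the catalog's views with local and global temporary views. Names and pattern are illustrative.

```java
import org.apache.spark.sql.SparkSession;

class ShowViewsExample {
  static void listViews(SparkSession spark) {
    // Output rows are (namespace, viewName, isTemporary); temp views are appended as in the exec above.
    spark.sql("SHOW VIEWS IN my_catalog.db").show();
    spark.sql("SHOW VIEWS IN my_catalog.db LIKE 'event*'").show();
  }
}
```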
+ */ +package org.apache.iceberg.spark.extensions; + +import java.util.Objects; + +public class Employee { + private Integer id; + private String dep; + + public Employee() {} + + public Employee(Integer id, String dep) { + this.id = id; + this.dep = dep; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + public String getDep() { + return dep; + } + + public void setDep(String dep) { + this.dep = dep; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other == null || getClass() != other.getClass()) { + return false; + } + + Employee employee = (Employee) other; + return Objects.equals(id, employee.id) && Objects.equals(dep, employee.dep); + } + + @Override + public int hashCode() { + return Objects.hash(id, dep); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/ExtensionsTestBase.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/ExtensionsTestBase.java new file mode 100644 index 000000000000..578845e3da2b --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/ExtensionsTestBase.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; + +import java.util.Random; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.hive.HiveCatalog; +import org.apache.iceberg.hive.TestHiveMetastore; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.TestBase; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.BeforeAll; + +public abstract class ExtensionsTestBase extends CatalogTestBase { + + private static final Random RANDOM = ThreadLocalRandom.current(); + + @BeforeAll + public static void startMetastoreAndSpark() { + TestBase.metastore = new TestHiveMetastore(); + metastore.start(); + TestBase.hiveConf = metastore.hiveConf(); + + TestBase.spark.close(); + + TestBase.spark = + SparkSession.builder() + .master("local[2]") + .config("spark.testing", "true") + .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic") + .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName()) + .config("spark.hadoop." 
+ METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .config("spark.sql.shuffle.partitions", "4") + .config("spark.sql.hive.metastorePartitionPruningFallbackOnException", "true") + .config("spark.sql.legacy.respectNullabilityInTextDatasetConversion", "true") + .config( + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), String.valueOf(RANDOM.nextBoolean())) + .enableHiveSupport() + .getOrCreate(); + + TestBase.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + + TestBase.catalog = + (HiveCatalog) + CatalogUtil.loadCatalog( + HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/ProcedureUtil.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/ProcedureUtil.java new file mode 100644 index 000000000000..de4acd74a7ed --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/ProcedureUtil.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.UUID; +import org.apache.iceberg.ImmutableGenericPartitionStatisticsFile; +import org.apache.iceberg.PartitionStatisticsFile; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.PositionOutputStream; + +public class ProcedureUtil { + + private ProcedureUtil() {} + + static PartitionStatisticsFile writePartitionStatsFile( + long snapshotId, String statsLocation, FileIO fileIO) { + PositionOutputStream positionOutputStream; + try { + positionOutputStream = fileIO.newOutputFile(statsLocation).create(); + positionOutputStream.close(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + return ImmutableGenericPartitionStatisticsFile.builder() + .snapshotId(snapshotId) + .fileSizeInBytes(42L) + .path(statsLocation) + .build(); + } + + static String statsFileLocation(String tableLocation) { + String statsFileName = "stats-file-" + UUID.randomUUID(); + return tableLocation.replaceFirst("file:", "") + "/metadata/" + statsFileName; + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java new file mode 100644 index 000000000000..830d07d86eab --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
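Editor's sketch (not part of the patch): ExtensionsTestBase above wires the Iceberg SQL extensions into a Hive-backed test session. A minimal equivalent outside the test harness, using a filesystem (Hadoop) catalog so no metastore is needed; it assumes an Iceberg Spark 4.0 runtime built from this patch (or a released equivalent) is on the classpath, and the catalog name and warehouse path are placeholders.

```java
import org.apache.spark.sql.SparkSession;

class LocalIcebergSessionExample {
  static SparkSession create() {
    return SparkSession.builder()
        .master("local[2]")
        .config("spark.sql.extensions",
            "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
        .config("spark.sql.catalog.my_catalog", "org.apache.iceberg.spark.SparkCatalog")
        .config("spark.sql.catalog.my_catalog.type", "hadoop")
        .config("spark.sql.catalog.my_catalog.warehouse", "file:///tmp/iceberg-warehouse")
        .getOrCreate();
  }
}
```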
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static scala.collection.JavaConverters.seqAsJavaListConverter; + +import java.util.Collection; +import java.util.List; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.execution.CommandResultExec; +import org.apache.spark.sql.execution.SparkPlan; +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper; +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec; +import scala.PartialFunction; +import scala.collection.Seq; + +public class SparkPlanUtil { + + private static final AdaptiveSparkPlanHelper SPARK_HELPER = new AdaptiveSparkPlanHelper() {}; + + private SparkPlanUtil() {} + + public static List collectLeaves(SparkPlan plan) { + return toJavaList(SPARK_HELPER.collectLeaves(actualPlan(plan))); + } + + public static List collectBatchScans(SparkPlan plan) { + List leaves = collectLeaves(plan); + return leaves.stream() + .filter(scan -> scan instanceof BatchScanExec) + .collect(Collectors.toList()); + } + + private static SparkPlan actualPlan(SparkPlan plan) { + if (plan instanceof CommandResultExec) { + return ((CommandResultExec) plan).commandPhysicalPlan(); + } else { + return plan; + } + } + + public static List collectExprs( + SparkPlan sparkPlan, Predicate predicate) { + Seq> seq = + SPARK_HELPER.collect( + sparkPlan, + new PartialFunction>() { + @Override + public List apply(SparkPlan plan) { + List exprs = Lists.newArrayList(); + + for (Expression expr : toJavaList(plan.expressions())) { + exprs.addAll(collectExprs(expr, predicate)); + } + + return exprs; + } + + @Override + public boolean isDefinedAt(SparkPlan plan) { + return true; + } + }); + return toJavaList(seq).stream().flatMap(Collection::stream).collect(Collectors.toList()); + } + + private static List collectExprs( + Expression expression, Predicate predicate) { + Seq seq = + expression.collect( + new PartialFunction() { + @Override + public Expression apply(Expression expr) { + return expr; + } + + @Override + public boolean isDefinedAt(Expression expr) { + return predicate.test(expr); + } + }); + return toJavaList(seq); + } + + private static List toJavaList(Seq seq) { + return seqAsJavaListConverter(seq).asJava(); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java new file mode 100644 index 000000000000..7af9dfc58737 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java @@ -0,0 
+1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.DataOperations.DELETE; +import static org.apache.iceberg.DataOperations.OVERWRITE; +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; +import static org.apache.iceberg.SnapshotSummary.ADDED_DELETE_FILES_PROP; +import static org.apache.iceberg.SnapshotSummary.ADDED_FILES_PROP; +import static org.apache.iceberg.SnapshotSummary.CHANGED_PARTITION_COUNT_PROP; +import static org.apache.iceberg.SnapshotSummary.DELETED_FILES_PROP; +import static org.apache.iceberg.TableProperties.DATA_PLANNING_MODE; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.DELETE_PLANNING_MODE; +import static org.apache.iceberg.TableProperties.ORC_VECTORIZATION_ENABLED; +import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; +import static org.apache.iceberg.TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_NONE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Files; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PlanningMode; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.parquet.GenericParquetWriter; +import org.apache.iceberg.io.DataWriter; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import 
org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.execution.SparkPlan; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public abstract class SparkRowLevelOperationsTestBase extends ExtensionsTestBase { + + private static final Random RANDOM = ThreadLocalRandom.current(); + + @Parameter(index = 3) + protected FileFormat fileFormat; + + @Parameter(index = 4) + protected boolean vectorized; + + @Parameter(index = 5) + protected String distributionMode; + + @Parameter(index = 6) + protected boolean fanoutEnabled; + + @Parameter(index = 7) + protected String branch; + + @Parameter(index = 8) + protected PlanningMode planningMode; + + @Parameters( + name = + "catalogName = {0}, implementation = {1}, config = {2}," + + " format = {3}, vectorized = {4}, distributionMode = {5}," + + " fanout = {6}, branch = {7}, planningMode = {8}") + public static Object[][] parameters() { + return new Object[][] { + { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default"), + FileFormat.ORC, + true, + WRITE_DISTRIBUTION_MODE_NONE, + true, + SnapshotRef.MAIN_BRANCH, + LOCAL + }, + { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default"), + FileFormat.PARQUET, + true, + WRITE_DISTRIBUTION_MODE_NONE, + false, + "test", + DISTRIBUTED + }, + { + "testhadoop", + SparkCatalog.class.getName(), + ImmutableMap.of("type", "hadoop"), + FileFormat.PARQUET, + RANDOM.nextBoolean(), + WRITE_DISTRIBUTION_MODE_HASH, + true, + null, + LOCAL + }, + { + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "clients", "1", + "parquet-enabled", "false", + "cache-enabled", + "false" // Spark will delete tables using v1, leaving the cache out of sync + ), + FileFormat.AVRO, + false, + WRITE_DISTRIBUTION_MODE_RANGE, + false, + "test", + DISTRIBUTED + } + }; + } + + protected abstract Map extraTableProperties(); + + protected void initTable() { + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s', '%s' '%s', '%s' '%s', '%s' '%s', '%s' '%s')", + tableName, + DEFAULT_FILE_FORMAT, + fileFormat, + WRITE_DISTRIBUTION_MODE, + distributionMode, + SPARK_WRITE_PARTITIONED_FANOUT_ENABLED, + String.valueOf(fanoutEnabled), + DATA_PLANNING_MODE, + planningMode.modeName(), + DELETE_PLANNING_MODE, + planningMode.modeName()); + + switch (fileFormat) { + case PARQUET: + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%b')", + tableName, PARQUET_VECTORIZATION_ENABLED, vectorized); + break; + case ORC: + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%b')", + tableName, ORC_VECTORIZATION_ENABLED, vectorized); + break; + case AVRO: + assertThat(vectorized).isFalse(); + break; + } + + Map props = extraTableProperties(); + props.forEach( + (prop, value) -> { + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, prop, value); + }); + } + + protected void createAndInitTable(String schema) { + createAndInitTable(schema, null); + } + + protected void createAndInitTable(String schema, String jsonData) { + createAndInitTable(schema, "", jsonData); + } + + protected void createAndInitTable(String schema, String partitioning, String jsonData) { + sql("CREATE TABLE %s (%s) USING iceberg %s", tableName, schema, partitioning); + initTable(); + + if (jsonData != null) { + try { + Dataset ds = toDS(schema, jsonData); + ds.coalesce(1).writeTo(tableName).append(); + 
createBranchIfNeeded(); + } catch (NoSuchTableException e) { + throw new RuntimeException("Failed to write data", e); + } + } + } + + protected void append(String table, String jsonData) { + append(table, null, jsonData); + } + + protected void append(String table, String schema, String jsonData) { + try { + Dataset ds = toDS(schema, jsonData); + ds.coalesce(1).writeTo(table).append(); + } catch (NoSuchTableException e) { + throw new RuntimeException("Failed to write data", e); + } + } + + protected void createOrReplaceView(String name, String jsonData) { + createOrReplaceView(name, null, jsonData); + } + + protected void createOrReplaceView(String name, String schema, String jsonData) { + Dataset ds = toDS(schema, jsonData); + ds.createOrReplaceTempView(name); + } + + protected void createOrReplaceView(String name, List data, Encoder encoder) { + spark.createDataset(data, encoder).createOrReplaceTempView(name); + } + + private Dataset toDS(String schema, String jsonData) { + List jsonRows = + Arrays.stream(jsonData.split("\n")) + .filter(str -> !str.trim().isEmpty()) + .collect(Collectors.toList()); + Dataset jsonDS = spark.createDataset(jsonRows, Encoders.STRING()); + + if (schema != null) { + return spark.read().schema(schema).json(jsonDS); + } else { + return spark.read().json(jsonDS); + } + } + + protected void validateDelete( + Snapshot snapshot, String changedPartitionCount, String deletedDataFiles) { + validateSnapshot(snapshot, DELETE, changedPartitionCount, deletedDataFiles, null, null); + } + + protected void validateCopyOnWrite( + Snapshot snapshot, + String changedPartitionCount, + String deletedDataFiles, + String addedDataFiles) { + String operation = null == addedDataFiles && null != deletedDataFiles ? DELETE : OVERWRITE; + validateSnapshot( + snapshot, operation, changedPartitionCount, deletedDataFiles, null, addedDataFiles); + } + + protected void validateMergeOnRead( + Snapshot snapshot, + String changedPartitionCount, + String addedDeleteFiles, + String addedDataFiles) { + String operation = null == addedDataFiles && null != addedDeleteFiles ? 
DELETE : OVERWRITE; + validateSnapshot( + snapshot, operation, changedPartitionCount, null, addedDeleteFiles, addedDataFiles); + } + + protected void validateSnapshot( + Snapshot snapshot, + String operation, + String changedPartitionCount, + String deletedDataFiles, + String addedDeleteFiles, + String addedDataFiles) { + assertThat(snapshot.operation()).as("Operation must match").isEqualTo(operation); + validateProperty(snapshot, CHANGED_PARTITION_COUNT_PROP, changedPartitionCount); + validateProperty(snapshot, DELETED_FILES_PROP, deletedDataFiles); + validateProperty(snapshot, ADDED_DELETE_FILES_PROP, addedDeleteFiles); + validateProperty(snapshot, ADDED_FILES_PROP, addedDataFiles); + } + + protected void validateProperty(Snapshot snapshot, String property, Set expectedValues) { + String actual = snapshot.summary().get(property); + assertThat(actual) + .as( + "Snapshot property " + + property + + " has unexpected value, actual = " + + actual + + ", expected one of : " + + String.join(",", expectedValues)) + .isIn(expectedValues); + } + + protected void validateProperty(Snapshot snapshot, String property, String expectedValue) { + if (null == expectedValue) { + assertThat(snapshot.summary()).doesNotContainKey(property); + } else { + assertThat(snapshot.summary()) + .as("Snapshot property " + property + " has unexpected value.") + .containsEntry(property, expectedValue); + } + } + + protected void sleep(long millis) { + try { + Thread.sleep(millis); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + protected DataFile writeDataFile(Table table, List records) { + try { + OutputFile file = + Files.localOutput( + temp.resolve(fileFormat.addExtension(UUID.randomUUID().toString())).toFile()); + + DataWriter dataWriter = + Parquet.writeData(file) + .forTable(table) + .createWriterFunc(GenericParquetWriter::buildWriter) + .overwrite() + .build(); + + try { + for (GenericRecord record : records) { + dataWriter.write(record); + } + } finally { + dataWriter.close(); + } + + return dataWriter.toDataFile(); + + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + protected String commitTarget() { + return branch == null ? tableName : String.format("%s.branch_%s", tableName, branch); + } + + @Override + protected String selectTarget() { + return branch == null ? 
tableName : String.format("%s VERSION AS OF '%s'", tableName, branch); + } + + protected void createBranchIfNeeded() { + if (branch != null && !branch.equals(SnapshotRef.MAIN_BRANCH)) { + sql("ALTER TABLE %s CREATE BRANCH %s", tableName, branch); + } + } + + // ORC currently does not support vectorized reads with deletes + protected boolean supportsVectorization() { + return vectorized && (isParquet() || isCopyOnWrite()); + } + + private boolean isParquet() { + return fileFormat.equals(FileFormat.PARQUET); + } + + private boolean isCopyOnWrite() { + return extraTableProperties().containsValue(RowLevelOperationMode.COPY_ON_WRITE.modeName()); + } + + protected void assertAllBatchScansVectorized(SparkPlan plan) { + List batchScans = SparkPlanUtil.collectBatchScans(plan); + assertThat(batchScans).hasSizeGreaterThan(0).allMatch(SparkPlan::supportsColumnar); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java new file mode 100644 index 000000000000..920c2f55eaaf --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java @@ -0,0 +1,1148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
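Editor's sketch (not part of the patch): SparkRowLevelOperationsTestBase above is abstract; concrete suites contribute table properties via extraTableProperties() and reuse its parameterized setup and SQL helpers. The subclass below is hypothetical and only illustrates that extension point.

```java
package org.apache.iceberg.spark.extensions;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.Map;
import org.apache.iceberg.RowLevelOperationMode;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.junit.jupiter.api.TestTemplate;

// Hypothetical suite: configures merge-on-read deletes and reuses the base-class helpers.
public class TestMergeOnReadDeleteExample extends SparkRowLevelOperationsTestBase {

  @Override
  protected Map<String, String> extraTableProperties() {
    return ImmutableMap.of(
        TableProperties.DELETE_MODE, RowLevelOperationMode.MERGE_ON_READ.modeName());
  }

  @TestTemplate
  public void deleteRemovesMatchingRows() {
    createAndInitTable("id INT, dep STRING",
        "{ \"id\": 1, \"dep\": \"hr\" }\n{ \"id\": 2, \"dep\": \"it\" }");

    sql("DELETE FROM %s WHERE id = 1", commitTarget());

    assertThat(sql("SELECT * FROM %s ORDER BY id", selectTarget())).hasSize(1);
  }
}
```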
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.File; +import java.nio.file.Path; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.DatumWriter; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.joda.time.DateTime; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestAddFilesProcedure extends ExtensionsTestBase { + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}, formatVersion = {3}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties(), + 1 + }, + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + 2 + }, + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties(), + 2 + } + }; + } + + @Parameter(index = 3) + private int formatVersion; + + private final String sourceTableName = "source_table"; + private File fileTableDir; + + @TempDir private Path temp; + + @BeforeEach + public void setupTempDirs() { + fileTableDir = temp.toFile(); + } + + @AfterEach + public void dropTables() { + sql("DROP TABLE IF EXISTS %s PURGE", sourceTableName); + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void addDataUnpartitioned() { + createUnpartitionedFileTable("parquet"); + + createIcebergTable("id Integer, name String, dept String, subdept String"); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertOutput(result, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void deleteAndAddBackUnpartitioned() { + createUnpartitionedFileTable("parquet"); + 
+ createIcebergTable("id Integer, name String, dept String, subdept String"); + + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + String deleteData = "DELETE FROM %s"; + sql(deleteData, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertOutput(result, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Disabled // TODO Classpath issues prevent us from actually writing to a Spark ORC table + public void addDataUnpartitionedOrc() { + createUnpartitionedFileTable("orc"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + Object result = + scalarSql( + "CALL %s.system.add_files('%s', '`orc`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertThat(result).isEqualTo(2L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addAvroFile() throws Exception { + // Spark Session Catalog cannot load metadata tables + // with "The namespace in session catalog must have exactly one name part" + assumeThat(catalogName).isNotEqualTo("spark_catalog"); + + // Create an Avro file + + Schema schema = + SchemaBuilder.record("record") + .fields() + .requiredInt("id") + .requiredString("data") + .endRecord(); + GenericRecord record1 = new GenericData.Record(schema); + record1.put("id", 1L); + record1.put("data", "a"); + GenericRecord record2 = new GenericData.Record(schema); + record2.put("id", 2L); + record2.put("data", "b"); + File outputFile = temp.resolve("test.avro").toFile(); + + DatumWriter datumWriter = new GenericDatumWriter(schema); + DataFileWriter dataFileWriter = new DataFileWriter(datumWriter); + dataFileWriter.create(schema, outputFile); + dataFileWriter.append(record1); + dataFileWriter.append(record2); + dataFileWriter.close(); + + createIcebergTable("id Long, data String"); + + List result = + sql( + "CALL %s.system.add_files('%s', '`avro`.`%s`')", + catalogName, tableName, outputFile.getPath()); + assertOutput(result, 1L, 1L); + + List expected = Lists.newArrayList(new Object[] {1L, "a"}, new Object[] {2L, "b"}); + + assertEquals( + "Iceberg table contains correct data", + expected, + sql("SELECT * FROM %s ORDER BY id", tableName)); + + List actualRecordCount = + sql("select %s from %s.files", DataFile.RECORD_COUNT.name(), tableName); + List expectedRecordCount = Lists.newArrayList(); + expectedRecordCount.add(new Object[] {2L}); + assertEquals( + "Iceberg file metadata should have correct metadata count", + expectedRecordCount, + actualRecordCount); + } + + // TODO Adding spark-avro doesn't work in tests + @Disabled + public void addDataUnpartitionedAvro() { + createUnpartitionedFileTable("avro"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + Object result = + scalarSql( + "CALL %s.system.add_files('%s', '`avro`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertThat(result).isEqualTo(2L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s 
ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addDataUnpartitionedHive() { + createUnpartitionedHiveTable(); + + createIcebergTable("id Integer, name String, dept String, subdept String"); + + List result = + sql("CALL %s.system.add_files('%s', '%s')", catalogName, tableName, sourceTableName); + + assertOutput(result, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addDataUnpartitionedExtraCol() { + createUnpartitionedFileTable("parquet"); + + createIcebergTable("id Integer, name String, dept String, subdept String, foo string"); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertOutput(result, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addDataUnpartitionedMissingCol() { + createUnpartitionedFileTable("parquet"); + + createIcebergTable("id Integer, name String, dept String"); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertOutput(result, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addDataPartitionedMissingCol() { + createPartitionedFileTable("parquet"); + + createIcebergTable("id Integer, name String, dept String", "PARTITIONED BY (id)"); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertOutput(result, 8L, 4L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addDataPartitioned() { + createPartitionedFileTable("parquet"); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)"); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertOutput(result, 8L, 4L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Disabled // TODO Classpath issues prevent us from actually writing to a Spark ORC table + public void addDataPartitionedOrc() { + createPartitionedFileTable("orc"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object result = + scalarSql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertThat(result).isEqualTo(8L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + // TODO 
Adding spark-avro doesn't work in tests + @Disabled + public void addDataPartitionedAvro() { + createPartitionedFileTable("avro"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object result = + scalarSql( + "CALL %s.system.add_files('%s', '`avro`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertThat(result).isEqualTo(8L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addDataPartitionedHive() { + createPartitionedHiveTable(); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)"); + + List result = + sql("CALL %s.system.add_files('%s', '%s')", catalogName, tableName, sourceTableName); + + assertOutput(result, 8L, 4L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addPartitionToPartitioned() { + createPartitionedFileTable("parquet"); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)"); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertOutput(result, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void deleteAndAddBackPartitioned() { + createPartitionedFileTable("parquet"); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)"); + + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + String deleteData = "DELETE FROM %s where id = 1"; + sql(deleteData, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertOutput(result, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addPartitionToPartitionedSnapshotIdInheritanceEnabledInTwoRuns() { + createPartitionedFileTable("parquet"); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", + tableName, TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED); + + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 2))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id < 3 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", 
tableName)); + + // verify manifest file name has uuid pattern + String manifestPath = (String) sql("select path from %s.manifests", tableName).get(0)[0]; + + Pattern uuidPattern = Pattern.compile("[a-f0-9]{8}(?:-[a-f0-9]{4}){4}[a-f0-9]{8}"); + + Matcher matcher = uuidPattern.matcher(manifestPath); + assertThat(matcher.find()).as("verify manifest path has uuid").isTrue(); + } + + @TestTemplate + public void addDataPartitionedByDateToPartitioned() { + createDatePartitionedFileTable("parquet"); + + createIcebergTable("id Integer, name String, date Date", "PARTITIONED BY (date)"); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('date', '2021-01-01'))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertOutput(result, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, date FROM %s WHERE date = '2021-01-01' ORDER BY id", sourceTableName), + sql("SELECT id, name, date FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addDataPartitionedVerifyPartitionTypeInferredCorrectly() { + createTableWithTwoPartitions("parquet"); + + createIcebergTable( + "id Integer, name String, date Date, dept String", "PARTITIONED BY (date, dept)"); + + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('date', '2021-01-01'))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + String sqlFormat = + "SELECT id, name, dept, date FROM %s WHERE date = '2021-01-01' and dept= '01' ORDER BY id"; + assertEquals( + "Iceberg table contains correct data", + sql(sqlFormat, sourceTableName), + sql(sqlFormat, tableName)); + } + + @TestTemplate + public void addFilteredPartitionsToPartitioned() { + createCompositePartitionedTable("parquet"); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id, dept)"); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertOutput(result, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addFilteredPartitionsToPartitioned2() { + createCompositePartitionedTable("parquet"); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id, dept)"); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', 'hr'))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertOutput(result, 6L, 3L); + + assertEquals( + "Iceberg table contains correct data", + sql( + "SELECT id, name, dept, subdept FROM %s WHERE dept = 'hr' ORDER BY id", + sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addFilteredPartitionsToPartitionedWithNullValueFilteringOnId() { + createCompositePartitionedTableWithNullValueInPartitionColumn("parquet"); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id, dept)"); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertOutput(result, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", 
sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addFilteredPartitionsToPartitionedWithNullValueFilteringOnDept() { + createCompositePartitionedTableWithNullValueInPartitionColumn("parquet"); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id, dept)"); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', 'hr'))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertOutput(result, 6L, 3L); + + assertEquals( + "Iceberg table contains correct data", + sql( + "SELECT id, name, dept, subdept FROM %s WHERE dept = 'hr' ORDER BY id", + sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addWeirdCaseHiveTable() { + createWeirdCaseTable(); + + createIcebergTable( + "id Integer, `naMe` String, dept String, subdept String", "PARTITIONED BY (`naMe`)"); + + List result = + sql( + "CALL %s.system.add_files('%s', '%s', map('naMe', 'John Doe'))", + catalogName, tableName, sourceTableName); + + assertOutput(result, 2L, 1L); + + /* + While we would like to use + SELECT id, `naMe`, dept, subdept FROM %s WHERE `naMe` = 'John Doe' ORDER BY id + Spark does not actually handle this pushdown correctly for hive based tables and it returns 0 records + */ + List expected = + sql("SELECT id, `naMe`, dept, subdept from %s ORDER BY id", sourceTableName).stream() + .filter(r -> r[1].equals("John Doe")) + .collect(Collectors.toList()); + + // TODO when this assert breaks Spark fixed the pushdown issue + assertThat( + sql( + "SELECT id, `naMe`, dept, subdept from %s WHERE `naMe` = 'John Doe' ORDER BY id", + sourceTableName)) + .as("If this assert breaks it means that Spark has fixed the pushdown issue") + .hasSize(0); + + // Pushdown works for iceberg + assertThat( + sql( + "SELECT id, `naMe`, dept, subdept FROM %s WHERE `naMe` = 'John Doe' ORDER BY id", + tableName)) + .as("We should be able to pushdown mixed case partition keys") + .hasSize(2); + + assertEquals( + "Iceberg table contains correct data", + expected, + sql("SELECT id, `naMe`, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void addPartitionToPartitionedHive() { + createPartitionedHiveTable(); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)"); + + List result = + sql( + "CALL %s.system.add_files('%s', '%s', map('id', 1))", + catalogName, tableName, sourceTableName); + + assertOutput(result, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void invalidDataImport() { + createPartitionedFileTable("parquet"); + + createIcebergTable("id Integer, name String, dept String, subdept String"); + + assertThatThrownBy( + () -> + scalarSql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Cannot use partition filter with an unpartitioned table"); + + assertThatThrownBy( + () -> + scalarSql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath())) + .isInstanceOf(IllegalArgumentException.class) + 
.hasMessageStartingWith("Cannot add partitioned files to an unpartitioned table"); + } + + @TestTemplate + public void invalidDataImportPartitioned() { + createUnpartitionedFileTable("parquet"); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)"); + + assertThatThrownBy( + () -> + scalarSql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('x', '1', 'y', '2'))", + catalogName, tableName, fileTableDir.getAbsolutePath())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Cannot add data files to target table") + .hasMessageContaining("is greater than the number of partitioned columns"); + + assertThatThrownBy( + () -> + scalarSql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', '2'))", + catalogName, tableName, fileTableDir.getAbsolutePath())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Cannot add files to target table") + .hasMessageContaining( + "specified partition filter refers to columns that are not partitioned"); + } + + @TestTemplate + public void addTwice() { + createPartitionedHiveTable(); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)"); + + List result1 = + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1))", + catalogName, tableName, sourceTableName); + assertOutput(result1, 2L, 1L); + + List result2 = + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 2))", + catalogName, tableName, sourceTableName); + assertOutput(result2, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", tableName)); + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 2 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 2 ORDER BY id", tableName)); + } + + @TestTemplate + public void duplicateDataPartitioned() { + createPartitionedHiveTable(); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)"); + + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1))", + catalogName, tableName, sourceTableName); + + assertThatThrownBy( + () -> + scalarSql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1))", + catalogName, tableName, sourceTableName)) + .isInstanceOf(IllegalStateException.class) + .hasMessageStartingWith( + "Cannot complete import because data files to be imported already" + + " exist within the target table"); + } + + @TestTemplate + public void duplicateDataPartitionedAllowed() { + createPartitionedHiveTable(); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)"); + + List result1 = + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1))", + catalogName, tableName, sourceTableName); + + assertOutput(result1, 2L, 1L); + + List result2 = + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1)," + + "check_duplicate_files => 
false)", + catalogName, tableName, sourceTableName); + + assertOutput(result2, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql( + "SELECT id, name, dept, subdept FROM %s WHERE id = 1 UNION ALL " + + "SELECT id, name, dept, subdept FROM %s WHERE id = 1", + sourceTableName, sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s", tableName, tableName)); + } + + @TestTemplate + public void duplicateDataUnpartitioned() { + createUnpartitionedHiveTable(); + + createIcebergTable("id Integer, name String, dept String, subdept String"); + + sql("CALL %s.system.add_files('%s', '%s')", catalogName, tableName, sourceTableName); + + assertThatThrownBy( + () -> + scalarSql( + "CALL %s.system.add_files('%s', '%s')", + catalogName, tableName, sourceTableName)) + .isInstanceOf(IllegalStateException.class) + .hasMessageStartingWith( + "Cannot complete import because data files to be imported already" + + " exist within the target table"); + } + + @TestTemplate + public void duplicateDataUnpartitionedAllowed() { + createUnpartitionedHiveTable(); + + createIcebergTable("id Integer, name String, dept String, subdept String"); + + List result1 = + sql("CALL %s.system.add_files('%s', '%s')", catalogName, tableName, sourceTableName); + assertOutput(result1, 2L, 1L); + + List result2 = + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s'," + + "check_duplicate_files => false)", + catalogName, tableName, sourceTableName); + assertOutput(result2, 2L, 1L); + + assertEquals( + "Iceberg table contains correct data", + sql( + "SELECT * FROM (SELECT * FROM %s UNION ALL " + "SELECT * from %s) ORDER BY id", + sourceTableName, sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testEmptyImportDoesNotThrow() { + createIcebergTable("id Integer, name String, dept String, subdept String"); + + // Empty path based import + List pathResult = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + assertOutput(pathResult, 0L, 0L); + assertEquals( + "Iceberg table contains no added data when importing from an empty path", + EMPTY_QUERY_RESULT, + sql("SELECT * FROM %s ORDER BY id", tableName)); + + // Empty table based import + String createHive = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) STORED AS parquet"; + sql(createHive, sourceTableName); + + List tableResult = + sql("CALL %s.system.add_files('%s', '%s')", catalogName, tableName, sourceTableName); + assertOutput(tableResult, 0L, 0L); + assertEquals( + "Iceberg table contains no added data when importing from an empty table", + EMPTY_QUERY_RESULT, + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testPartitionedImportFromEmptyPartitionDoesNotThrow() { + createPartitionedHiveTable(); + + final int emptyPartitionId = 999; + // Add an empty partition to the hive table + sql( + "ALTER TABLE %s ADD PARTITION (id = '%d') LOCATION '%d'", + sourceTableName, emptyPartitionId, emptyPartitionId); + + createIcebergTable( + "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)"); + + List tableResult = + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', %d))", + catalogName, tableName, sourceTableName, emptyPartitionId); + + assertOutput(tableResult, 0L, 0L); + assertEquals( + "Iceberg table contains no added data when importing from an empty 
table", + EMPTY_QUERY_RESULT, + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testAddFilesWithParallelism() { + createUnpartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files(table => '%s', source_table => '%s', parallelism => 2)", + catalogName, tableName, sourceTableName); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + private static final List EMPTY_QUERY_RESULT = Lists.newArrayList(); + + private static final StructField[] STRUCT = { + new StructField("id", DataTypes.IntegerType, true, Metadata.empty()), + new StructField("name", DataTypes.StringType, true, Metadata.empty()), + new StructField("dept", DataTypes.StringType, true, Metadata.empty()), + new StructField("subdept", DataTypes.StringType, true, Metadata.empty()) + }; + + private Dataset unpartitionedDF() { + return spark + .createDataFrame( + ImmutableList.of( + RowFactory.create(1, "John Doe", "hr", "communications"), + RowFactory.create(2, "Jane Doe", "hr", "salary"), + RowFactory.create(3, "Matt Doe", "hr", "communications"), + RowFactory.create(4, "Will Doe", "facilities", "all")), + new StructType(STRUCT)) + .repartition(1); + } + + private Dataset singleNullRecordDF() { + return spark + .createDataFrame( + ImmutableList.of(RowFactory.create(null, null, null, null)), new StructType(STRUCT)) + .repartition(1); + } + + private Dataset partitionedDF() { + return unpartitionedDF().select("name", "dept", "subdept", "id"); + } + + private Dataset compositePartitionedDF() { + return unpartitionedDF().select("name", "subdept", "id", "dept"); + } + + private Dataset compositePartitionedNullRecordDF() { + return singleNullRecordDF().select("name", "subdept", "id", "dept"); + } + + private Dataset weirdColumnNamesDF() { + Dataset unpartitionedDF = unpartitionedDF(); + return unpartitionedDF.select( + unpartitionedDF.col("id"), + unpartitionedDF.col("subdept"), + unpartitionedDF.col("dept"), + unpartitionedDF.col("name").as("naMe")); + } + + private static final StructField[] DATE_STRUCT = { + new StructField("id", DataTypes.IntegerType, true, Metadata.empty()), + new StructField("name", DataTypes.StringType, true, Metadata.empty()), + new StructField("ts", DataTypes.DateType, true, Metadata.empty()), + new StructField("dept", DataTypes.StringType, true, Metadata.empty()), + }; + + private static java.sql.Date toDate(String value) { + return new java.sql.Date(DateTime.parse(value).getMillis()); + } + + private Dataset dateDF() { + return spark + .createDataFrame( + ImmutableList.of( + RowFactory.create(1, "John Doe", toDate("2021-01-01"), "01"), + RowFactory.create(2, "Jane Doe", toDate("2021-01-01"), "01"), + RowFactory.create(3, "Matt Doe", toDate("2021-01-02"), "02"), + RowFactory.create(4, "Will Doe", toDate("2021-01-02"), "02")), + new StructType(DATE_STRUCT)) + .repartition(2); + } + + private void createUnpartitionedFileTable(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s LOCATION '%s'"; + + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + Dataset df = unpartitionedDF(); + 
df.write().insertInto(sourceTableName); + df.write().insertInto(sourceTableName); + } + + private void createPartitionedFileTable(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s PARTITIONED BY (id) " + + "LOCATION '%s'"; + + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + Dataset df = partitionedDF(); + df.write().insertInto(sourceTableName); + df.write().insertInto(sourceTableName); + } + + private void createCompositePartitionedTable(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s " + + "PARTITIONED BY (id, dept) LOCATION '%s'"; + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + Dataset df = compositePartitionedDF(); + df.write().insertInto(sourceTableName); + df.write().insertInto(sourceTableName); + } + + private void createCompositePartitionedTableWithNullValueInPartitionColumn(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s " + + "PARTITIONED BY (id, dept) LOCATION '%s'"; + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + Dataset unionedDF = + compositePartitionedDF() + .unionAll(compositePartitionedNullRecordDF()) + .select("name", "subdept", "id", "dept") + .repartition(1); + + unionedDF.write().insertInto(sourceTableName); + unionedDF.write().insertInto(sourceTableName); + } + + private void createWeirdCaseTable() { + String createParquet = + "CREATE TABLE %s (id Integer, subdept String, dept String) " + + "PARTITIONED BY (`naMe` String) STORED AS parquet"; + + sql(createParquet, sourceTableName); + + Dataset df = weirdColumnNamesDF(); + df.write().insertInto(sourceTableName); + df.write().insertInto(sourceTableName); + } + + private void createUnpartitionedHiveTable() { + String createHive = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) STORED AS parquet"; + + sql(createHive, sourceTableName); + + Dataset df = unpartitionedDF(); + df.write().insertInto(sourceTableName); + df.write().insertInto(sourceTableName); + } + + private void createPartitionedHiveTable() { + String createHive = + "CREATE TABLE %s (name String, dept String, subdept String) " + + "PARTITIONED BY (id Integer) STORED AS parquet"; + + sql(createHive, sourceTableName); + + Dataset df = partitionedDF(); + df.write().insertInto(sourceTableName); + df.write().insertInto(sourceTableName); + } + + private void createDatePartitionedFileTable(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, date Date) USING %s " + + "PARTITIONED BY (date) LOCATION '%s'"; + + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + dateDF().select("id", "name", "ts").write().insertInto(sourceTableName); + } + + private void createTableWithTwoPartitions(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, date Date, dept String) USING %s " + + "PARTITIONED BY (date, dept) LOCATION '%s'"; + + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + dateDF().write().insertInto(sourceTableName); + } + + private void createIcebergTable(String schema) { + createIcebergTable(schema, ""); + } + + private void createIcebergTable(String schema, String partitioning) { + sql( + "CREATE TABLE %s (%s) USING iceberg %s TBLPROPERTIES ('%s' '%d')", + tableName, schema, partitioning, 
TableProperties.FORMAT_VERSION, formatVersion);
+  }
+
+  private void assertOutput(
+      List<Object[]> result, long expectedAddedFilesCount, long expectedChangedPartitionCount) {
+    Object[] output = Iterables.getOnlyElement(result);
+    assertThat(output[0]).isEqualTo(expectedAddedFilesCount);
+    if (formatVersion == 1) {
+      assertThat(output[1]).isEqualTo(expectedChangedPartitionCount);
+    } else {
+      // the number of changed partitions may not be populated in v2 tables
+      assertThat(output[1]).isIn(expectedChangedPartitionCount, null);
+    }
+  }
+}
diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTablePartitionFields.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTablePartitionFields.java
new file mode 100644
index 000000000000..38e5c942c9ff
--- /dev/null
+++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTablePartitionFields.java
@@ -0,0 +1,585 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.spark.sql.connector.catalog.CatalogManager; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestAlterTablePartitionFields extends ExtensionsTestBase { + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}, formatVersion = {3}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties(), + 1 + }, + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties(), + 2 + } + }; + } + + @Parameter(index = 3) + private int formatVersion; + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testAddIdentityPartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD category", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).identity("category").build(); + + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } + + @TestTemplate + public void testAddBucketPartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .bucket("id", 16, "id_bucket_16") + .build(); + + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } + + @TestTemplate + public void testAddTruncatePartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD truncate(data, 4)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .truncate("data", 4, "data_trunc_4") + .build(); + + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } + + @TestTemplate + public void testAddYearsPartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = 
validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD years(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).year("ts").build(); + + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } + + @TestTemplate + public void testAddMonthsPartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD months(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).month("ts").build(); + + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } + + @TestTemplate + public void testAddDaysPartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).day("ts").build(); + + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } + + @TestTemplate + public void testAddHoursPartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD hours(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).hour("ts").build(); + + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } + + @TestTemplate + public void testAddYearPartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD year(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).year("ts").build(); + + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } + + @TestTemplate + public void testAddMonthPartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD month(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).month("ts").build(); + + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } + + @TestTemplate + public void testAddDayPartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().isUnpartitioned()).as("Table should start 
unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD day(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).day("ts").build(); + + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } + + @TestTemplate + public void testAddHourPartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD hour(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).hour("ts").build(); + + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } + + @TestTemplate + public void testAddNamedPartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id) AS shard", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).bucket("id", 16, "shard").build(); + + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } + + @TestTemplate + public void testDropIdentityPartition() { + createTable("id bigint NOT NULL, category string, data string", "category"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().fields()).as("Table should start with 1 partition field").hasSize(1); + + sql("ALTER TABLE %s DROP PARTITION FIELD category", tableName); + + table.refresh(); + + if (formatVersion == 1) { + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .alwaysNull("category", "category") + .build(); + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } else { + assertThat(table.spec().isUnpartitioned()).as("New spec must be unpartitioned").isTrue(); + } + } + + @TestTemplate + public void testDropDaysPartition() { + createTable("id bigint NOT NULL, ts timestamp, data string", "days(ts)"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().fields()).as("Table should start with 1 partition field").hasSize(1); + + sql("ALTER TABLE %s DROP PARTITION FIELD days(ts)", tableName); + + table.refresh(); + + if (formatVersion == 1) { + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).alwaysNull("ts", "ts_day").build(); + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } else { + assertThat(table.spec().isUnpartitioned()).as("New spec must be unpartitioned").isTrue(); + } + } + + @TestTemplate + public void testDropBucketPartition() { + createTable("id bigint NOT NULL, data string", "bucket(16, id)"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().fields()).as("Table should start with 1 partition field").hasSize(1); + + sql("ALTER TABLE %s DROP PARTITION FIELD bucket(16, id)", tableName); + + table.refresh(); + + if (formatVersion == 1) { + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .alwaysNull("id", "id_bucket") + .build(); + assertThat(table.spec()).as("Should have new spec 
field").isEqualTo(expected); + } else { + assertThat(table.spec().isUnpartitioned()).as("New spec must be unpartitioned").isTrue(); + } + } + + @TestTemplate + public void testDropPartitionByName() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id) AS shard", tableName); + + table.refresh(); + + assertThat(table.spec().fields()).as("Table should have 1 partition field").hasSize(1); + + // Should be recognized as iceberg command even with extra white spaces + sql("ALTER TABLE %s DROP PARTITION \n FIELD shard", tableName); + + table.refresh(); + + if (formatVersion == 1) { + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(2).alwaysNull("id", "shard").build(); + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + } else { + assertThat(table.spec().isUnpartitioned()).as("New spec must be unpartitioned").isTrue(); + } + } + + @TestTemplate + public void testReplacePartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName); + table.refresh(); + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).day("ts").build(); + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD days(ts) WITH hours(ts)", tableName); + table.refresh(); + if (formatVersion == 1) { + expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(2) + .alwaysNull("ts", "ts_day") + .hour("ts") + .build(); + } else { + expected = + TestHelpers.newExpectedSpecBuilder() + .withSchema(table.schema()) + .withSpecId(2) + .addField("hour", 3, 1001, "ts_hour") + .build(); + } + assertThat(table.spec()) + .as("Should changed from daily to hourly partitioned field") + .isEqualTo(expected); + } + + @TestTemplate + public void testReplacePartitionAndRename() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName); + table.refresh(); + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).day("ts").build(); + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD days(ts) WITH hours(ts) AS hour_col", tableName); + table.refresh(); + if (formatVersion == 1) { + expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(2) + .alwaysNull("ts", "ts_day") + .hour("ts", "hour_col") + .build(); + } else { + expected = + TestHelpers.newExpectedSpecBuilder() + .withSchema(table.schema()) + .withSpecId(2) + .addField("hour", 3, 1001, "hour_col") + .build(); + } + assertThat(table.spec()) + .as("Should changed from daily to hourly partitioned field") + .isEqualTo(expected); + } + + @TestTemplate + public void testReplaceNamedPartition() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = 
validationCatalog.loadTable(tableIdent); + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts) AS day_col", tableName); + table.refresh(); + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).day("ts", "day_col").build(); + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD day_col WITH hours(ts)", tableName); + table.refresh(); + if (formatVersion == 1) { + expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(2) + .alwaysNull("ts", "day_col") + .hour("ts") + .build(); + } else { + expected = + TestHelpers.newExpectedSpecBuilder() + .withSchema(table.schema()) + .withSpecId(2) + .addField("hour", 3, 1001, "ts_hour") + .build(); + } + assertThat(table.spec()) + .as("Should changed from daily to hourly partitioned field") + .isEqualTo(expected); + } + + @TestTemplate + public void testReplaceNamedPartitionAndRenameDifferently() { + createTable("id bigint NOT NULL, category string, ts timestamp, data string"); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.spec().isUnpartitioned()).as("Table should start unpartitioned").isTrue(); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts) AS day_col", tableName); + table.refresh(); + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).day("ts", "day_col").build(); + assertThat(table.spec()).as("Should have new spec field").isEqualTo(expected); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD day_col WITH hours(ts) AS hour_col", tableName); + table.refresh(); + if (formatVersion == 1) { + expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(2) + .alwaysNull("ts", "day_col") + .hour("ts", "hour_col") + .build(); + } else { + expected = + TestHelpers.newExpectedSpecBuilder() + .withSchema(table.schema()) + .withSpecId(2) + .addField("hour", 3, 1001, "hour_col") + .build(); + } + assertThat(table.spec()) + .as("Should changed from daily to hourly partitioned field") + .isEqualTo(expected); + } + + @TestTemplate + public void testSparkTableAddDropPartitions() throws Exception { + createTable("id bigint NOT NULL, ts timestamp, data string"); + assertThat(sparkTable().partitioning()).as("spark table partition should be empty").hasSize(0); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id) AS shard", tableName); + assertPartitioningEquals(sparkTable(), 1, "bucket(16, id)"); + + sql("ALTER TABLE %s ADD PARTITION FIELD truncate(data, 4)", tableName); + assertPartitioningEquals(sparkTable(), 2, "truncate(4, data)"); + + sql("ALTER TABLE %s ADD PARTITION FIELD years(ts)", tableName); + assertPartitioningEquals(sparkTable(), 3, "years(ts)"); + + sql("ALTER TABLE %s DROP PARTITION FIELD years(ts)", tableName); + assertPartitioningEquals(sparkTable(), 2, "truncate(4, data)"); + + sql("ALTER TABLE %s DROP PARTITION FIELD truncate(4, data)", tableName); + assertPartitioningEquals(sparkTable(), 1, "bucket(16, id)"); + + sql("ALTER TABLE %s DROP PARTITION FIELD shard", tableName); + sql("DESCRIBE %s", tableName); + assertThat(sparkTable().partitioning()).as("spark table partition should be empty").hasSize(0); + } + + @TestTemplate + public void testDropColumnOfOldPartitionFieldV1() { + // default table created in v1 format + sql( + "CREATE TABLE %s (id bigint NOT NULL, ts timestamp, day_of_ts date) USING iceberg PARTITIONED BY (day_of_ts) TBLPROPERTIES('format-version' = '1')", 
+ tableName); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD day_of_ts WITH days(ts)", tableName); + + sql("ALTER TABLE %s DROP COLUMN day_of_ts", tableName); + } + + @TestTemplate + public void testDropColumnOfOldPartitionFieldV2() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, ts timestamp, day_of_ts date) USING iceberg PARTITIONED BY (day_of_ts) TBLPROPERTIES('format-version' = '2')", + tableName); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD day_of_ts WITH days(ts)", tableName); + + sql("ALTER TABLE %s DROP COLUMN day_of_ts", tableName); + } + + private void assertPartitioningEquals(SparkTable table, int len, String transform) { + assertThat(table.partitioning()).as("spark table partition should be " + len).hasSize(len); + assertThat(table.partitioning()[len - 1].toString()) + .as("latest spark table partition transform should match") + .isEqualTo(transform); + } + + private SparkTable sparkTable() throws Exception { + validationCatalog.loadTable(tableIdent).refresh(); + CatalogManager catalogManager = spark.sessionState().catalogManager(); + TableCatalog catalog = (TableCatalog) catalogManager.catalog(catalogName); + Identifier identifier = Identifier.of(tableIdent.namespace().levels(), tableIdent.name()); + return (SparkTable) catalog.loadTable(identifier); + } + + private void createTable(String schema) { + createTable(schema, null); + } + + private void createTable(String schema, String spec) { + if (spec == null) { + sql( + "CREATE TABLE %s (%s) USING iceberg TBLPROPERTIES ('%s' '%d')", + tableName, schema, TableProperties.FORMAT_VERSION, formatVersion); + } else { + sql( + "CREATE TABLE %s (%s) USING iceberg PARTITIONED BY (%s) TBLPROPERTIES ('%s' '%d')", + tableName, schema, spec, TableProperties.FORMAT_VERSION, formatVersion); + } + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTableSchema.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTableSchema.java new file mode 100644 index 000000000000..71c85b135859 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTableSchema.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestAlterTableSchema extends ExtensionsTestBase { + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testSetIdentifierFields() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, " + + "location struct NOT NULL) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.schema().identifierFieldIds()) + .as("Table should start without identifier") + .isEmpty(); + + sql("ALTER TABLE %s SET IDENTIFIER FIELDS id", tableName); + table.refresh(); + assertThat(table.schema().identifierFieldIds()) + .as("Should have new identifier field") + .isEqualTo(Sets.newHashSet(table.schema().findField("id").fieldId())); + + sql("ALTER TABLE %s SET IDENTIFIER FIELDS id, location.lon", tableName); + table.refresh(); + assertThat(table.schema().identifierFieldIds()) + .as("Should have new identifier field") + .isEqualTo( + Sets.newHashSet( + table.schema().findField("id").fieldId(), + table.schema().findField("location.lon").fieldId())); + + sql("ALTER TABLE %s SET IDENTIFIER FIELDS location.lon", tableName); + table.refresh(); + assertThat(table.schema().identifierFieldIds()) + .as("Should have new identifier field") + .isEqualTo(Sets.newHashSet(table.schema().findField("location.lon").fieldId())); + } + + @TestTemplate + public void testSetInvalidIdentifierFields() { + sql("CREATE TABLE %s (id bigint NOT NULL, id2 bigint) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.schema().identifierFieldIds()) + .as("Table should start without identifier") + .isEmpty(); + assertThatThrownBy(() -> sql("ALTER TABLE %s SET IDENTIFIER FIELDS unknown", tableName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageEndingWith("not found in current schema or added columns"); + + assertThatThrownBy(() -> sql("ALTER TABLE %s SET IDENTIFIER FIELDS id2", tableName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageEndingWith("not a required field"); + } + + @TestTemplate + public void testDropIdentifierFields() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, " + + "location struct NOT NULL) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.schema().identifierFieldIds()) + .as("Table should start without identifier") + .isEmpty(); + + sql("ALTER TABLE %s SET IDENTIFIER FIELDS id, location.lon", tableName); + table.refresh(); + assertThat(table.schema().identifierFieldIds()) + .as("Should have new identifier fields") + .isEqualTo( + Sets.newHashSet( + table.schema().findField("id").fieldId(), + table.schema().findField("location.lon").fieldId())); + + sql("ALTER TABLE %s DROP IDENTIFIER FIELDS id", tableName); + table.refresh(); + assertThat(table.schema().identifierFieldIds()) + .as("Should removed identifier field") + .isEqualTo(Sets.newHashSet(table.schema().findField("location.lon").fieldId())); + + sql("ALTER TABLE %s SET IDENTIFIER FIELDS id, 
location.lon", tableName); + table.refresh(); + assertThat(table.schema().identifierFieldIds()) + .as("Should have new identifier fields") + .isEqualTo( + Sets.newHashSet( + table.schema().findField("id").fieldId(), + table.schema().findField("location.lon").fieldId())); + + sql("ALTER TABLE %s DROP IDENTIFIER FIELDS id, location.lon", tableName); + table.refresh(); + assertThat(table.schema().identifierFieldIds()) + .as("Should have no identifier field") + .isEqualTo(Sets.newHashSet()); + } + + @TestTemplate + public void testDropInvalidIdentifierFields() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string NOT NULL, " + + "location struct NOT NULL) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.schema().identifierFieldIds()) + .as("Table should start without identifier") + .isEmpty(); + assertThatThrownBy(() -> sql("ALTER TABLE %s DROP IDENTIFIER FIELDS unknown", tableName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot complete drop identifier fields operation: field unknown not found"); + + sql("ALTER TABLE %s SET IDENTIFIER FIELDS id", tableName); + assertThatThrownBy(() -> sql("ALTER TABLE %s DROP IDENTIFIER FIELDS data", tableName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Cannot complete drop identifier fields operation: data is not an identifier field"); + + assertThatThrownBy(() -> sql("ALTER TABLE %s DROP IDENTIFIER FIELDS location.lon", tableName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Cannot complete drop identifier fields operation: location.lon is not an identifier field"); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAncestorsOfProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAncestorsOfProcedure.java new file mode 100644 index 000000000000..4a3a158dea52 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAncestorsOfProcedure.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestAncestorsOfProcedure extends ExtensionsTestBase { + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testAncestorOfUsingEmptyArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Long currentSnapshotId = table.currentSnapshot().snapshotId(); + Long currentTimestamp = table.currentSnapshot().timestampMillis(); + Long preSnapshotId = table.currentSnapshot().parentId(); + Long preTimeStamp = table.snapshot(table.currentSnapshot().parentId()).timestampMillis(); + + List output = sql("CALL %s.system.ancestors_of('%s')", catalogName, tableIdent); + + assertEquals( + "Procedure output must match", + ImmutableList.of( + row(currentSnapshotId, currentTimestamp), row(preSnapshotId, preTimeStamp)), + output); + } + + @TestTemplate + public void testAncestorOfUsingSnapshotId() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Long currentSnapshotId = table.currentSnapshot().snapshotId(); + Long currentTimestamp = table.currentSnapshot().timestampMillis(); + Long preSnapshotId = table.currentSnapshot().parentId(); + Long preTimeStamp = table.snapshot(table.currentSnapshot().parentId()).timestampMillis(); + + assertEquals( + "Procedure output must match", + ImmutableList.of( + row(currentSnapshotId, currentTimestamp), row(preSnapshotId, preTimeStamp)), + sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, currentSnapshotId)); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(preSnapshotId, preTimeStamp)), + sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, preSnapshotId)); + } + + @TestTemplate + public void testAncestorOfWithRollBack() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + table.refresh(); + Long firstSnapshotId = table.currentSnapshot().snapshotId(); + Long firstTimestamp = table.currentSnapshot().timestampMillis(); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + table.refresh(); + Long secondSnapshotId = table.currentSnapshot().snapshotId(); + Long secondTimestamp = table.currentSnapshot().timestampMillis(); + sql("INSERT INTO TABLE %s VALUES (3, 'c')", tableName); + table.refresh(); + Long thirdSnapshotId = table.currentSnapshot().snapshotId(); + Long thirdTimestamp = table.currentSnapshot().timestampMillis(); + + // roll back + sql( + "CALL %s.system.rollback_to_snapshot('%s', %dL)", + catalogName, tableIdent, 
secondSnapshotId); + + sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName); + table.refresh(); + Long fourthSnapshotId = table.currentSnapshot().snapshotId(); + Long fourthTimestamp = table.currentSnapshot().timestampMillis(); + + assertEquals( + "Procedure output must match", + ImmutableList.of( + row(fourthSnapshotId, fourthTimestamp), + row(secondSnapshotId, secondTimestamp), + row(firstSnapshotId, firstTimestamp)), + sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, fourthSnapshotId)); + + assertEquals( + "Procedure output must match", + ImmutableList.of( + row(thirdSnapshotId, thirdTimestamp), + row(secondSnapshotId, secondTimestamp), + row(firstSnapshotId, firstTimestamp)), + sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, thirdSnapshotId)); + } + + @TestTemplate + public void testAncestorOfUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Long firstSnapshotId = table.currentSnapshot().snapshotId(); + Long firstTimestamp = table.currentSnapshot().timestampMillis(); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(firstSnapshotId, firstTimestamp)), + sql( + "CALL %s.system.ancestors_of(snapshot_id => %dL, table => '%s')", + catalogName, firstSnapshotId, tableIdent)); + } + + @TestTemplate + public void testInvalidAncestorOfCases() { + assertThatThrownBy(() -> sql("CALL %s.system.ancestors_of()", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [table]"); + + assertThatThrownBy(() -> sql("CALL %s.system.ancestors_of('')", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot handle an empty identifier for parameter 'table'"); + + assertThatThrownBy(() -> sql("CALL %s.system.ancestors_of('%s', 1.1)", catalogName, tableIdent)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith("Wrong arg type for snapshot_id: cannot cast"); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestBranchDDL.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestBranchDDL.java new file mode 100644 index 000000000000..fb7f73186ad3 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestBranchDDL.java @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.extensions.IcebergParseException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestBranchDDL extends ExtensionsTestBase { + + @BeforeEach + public void createTable() { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + } + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties() + } + }; + } + + @TestTemplate + public void testCreateBranch() throws NoSuchTableException { + Table table = insertRows(); + long snapshotId = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + Integer minSnapshotsToKeep = 2; + long maxSnapshotAge = 2L; + long maxRefAge = 10L; + sql( + "ALTER TABLE %s CREATE BRANCH %s AS OF VERSION %d RETAIN %d DAYS WITH SNAPSHOT RETENTION %d SNAPSHOTS %d days", + tableName, branchName, snapshotId, maxRefAge, minSnapshotsToKeep, maxSnapshotAge); + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref.snapshotId()).isEqualTo(table.currentSnapshot().snapshotId()); + assertThat(ref.minSnapshotsToKeep()).isEqualTo(minSnapshotsToKeep); + assertThat(ref.maxSnapshotAgeMs().longValue()) + .isEqualTo(TimeUnit.DAYS.toMillis(maxSnapshotAge)); + assertThat(ref.maxRefAgeMs().longValue()).isEqualTo(TimeUnit.DAYS.toMillis(maxRefAge)); + + assertThatThrownBy(() -> sql("ALTER TABLE %s CREATE BRANCH %s", tableName, branchName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Ref b1 already exists"); + } + + @TestTemplate + public void testCreateBranchOnEmptyTable() { + String branchName = "b1"; + sql("ALTER TABLE %s CREATE BRANCH %s", tableName, "b1"); + Table table = validationCatalog.loadTable(tableIdent); + + SnapshotRef mainRef = table.refs().get(SnapshotRef.MAIN_BRANCH); + assertThat(mainRef).isNull(); + + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref).isNotNull(); + assertThat(ref.minSnapshotsToKeep()).isNull(); + assertThat(ref.maxSnapshotAgeMs()).isNull(); + assertThat(ref.maxRefAgeMs()).isNull(); + + Snapshot snapshot = table.snapshot(ref.snapshotId()); + assertThat(snapshot.parentId()).isNull(); + assertThat(snapshot.addedDataFiles(table.io())).isEmpty(); + assertThat(snapshot.removedDataFiles(table.io())).isEmpty(); + assertThat(snapshot.addedDeleteFiles(table.io())).isEmpty(); + 
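+    // creating a branch on an empty table commits an empty snapshot, so it has no parent
+    // and no added or removed data/delete files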
assertThat(snapshot.removedDeleteFiles(table.io())).isEmpty(); + } + + @TestTemplate + public void testCreateBranchUseDefaultConfig() throws NoSuchTableException { + Table table = insertRows(); + String branchName = "b1"; + sql("ALTER TABLE %s CREATE BRANCH %s", tableName, branchName); + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref.snapshotId()).isEqualTo(table.currentSnapshot().snapshotId()); + assertThat(ref.minSnapshotsToKeep()).isNull(); + assertThat(ref.maxSnapshotAgeMs()).isNull(); + assertThat(ref.maxRefAgeMs()).isNull(); + } + + @TestTemplate + public void testCreateBranchUseCustomMinSnapshotsToKeep() throws NoSuchTableException { + Integer minSnapshotsToKeep = 2; + Table table = insertRows(); + String branchName = "b1"; + sql( + "ALTER TABLE %s CREATE BRANCH %s WITH SNAPSHOT RETENTION %d SNAPSHOTS", + tableName, branchName, minSnapshotsToKeep); + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref.snapshotId()).isEqualTo(table.currentSnapshot().snapshotId()); + assertThat(ref.minSnapshotsToKeep()).isEqualTo(minSnapshotsToKeep); + assertThat(ref.maxSnapshotAgeMs()).isNull(); + assertThat(ref.maxRefAgeMs()).isNull(); + } + + @TestTemplate + public void testCreateBranchUseCustomMaxSnapshotAge() throws NoSuchTableException { + long maxSnapshotAge = 2L; + Table table = insertRows(); + String branchName = "b1"; + sql( + "ALTER TABLE %s CREATE BRANCH %s WITH SNAPSHOT RETENTION %d DAYS", + tableName, branchName, maxSnapshotAge); + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref).isNotNull(); + assertThat(ref.minSnapshotsToKeep()).isNull(); + assertThat(ref.maxSnapshotAgeMs().longValue()) + .isEqualTo(TimeUnit.DAYS.toMillis(maxSnapshotAge)); + assertThat(ref.maxRefAgeMs()).isNull(); + } + + @TestTemplate + public void testCreateBranchIfNotExists() throws NoSuchTableException { + long maxSnapshotAge = 2L; + Table table = insertRows(); + String branchName = "b1"; + sql( + "ALTER TABLE %s CREATE BRANCH %s WITH SNAPSHOT RETENTION %d DAYS", + tableName, branchName, maxSnapshotAge); + sql("ALTER TABLE %s CREATE BRANCH IF NOT EXISTS %s", tableName, branchName); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref.snapshotId()).isEqualTo(table.currentSnapshot().snapshotId()); + assertThat(ref.minSnapshotsToKeep()).isNull(); + assertThat(ref.maxSnapshotAgeMs().longValue()) + .isEqualTo(TimeUnit.DAYS.toMillis(maxSnapshotAge)); + assertThat(ref.maxRefAgeMs()).isNull(); + } + + @TestTemplate + public void testCreateBranchUseCustomMinSnapshotsToKeepAndMaxSnapshotAge() + throws NoSuchTableException { + Integer minSnapshotsToKeep = 2; + long maxSnapshotAge = 2L; + Table table = insertRows(); + String branchName = "b1"; + sql( + "ALTER TABLE %s CREATE BRANCH %s WITH SNAPSHOT RETENTION %d SNAPSHOTS %d DAYS", + tableName, branchName, minSnapshotsToKeep, maxSnapshotAge); + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref.snapshotId()).isEqualTo(table.currentSnapshot().snapshotId()); + assertThat(ref.minSnapshotsToKeep()).isEqualTo(minSnapshotsToKeep); + assertThat(ref.maxSnapshotAgeMs().longValue()) + .isEqualTo(TimeUnit.DAYS.toMillis(maxSnapshotAge)); + assertThat(ref.maxRefAgeMs()).isNull(); + + assertThatThrownBy( + () -> + sql( + "ALTER TABLE %s CREATE BRANCH %s WITH SNAPSHOT RETENTION", + tableName, branchName)) + .isInstanceOf(IcebergParseException.class) + .hasMessageContaining("no viable alternative at input 'WITH SNAPSHOT 
RETENTION'"); + } + + @TestTemplate + public void testCreateBranchUseCustomMaxRefAge() throws NoSuchTableException { + long maxRefAge = 10L; + Table table = insertRows(); + String branchName = "b1"; + sql("ALTER TABLE %s CREATE BRANCH %s RETAIN %d DAYS", tableName, branchName, maxRefAge); + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref.snapshotId()).isEqualTo(table.currentSnapshot().snapshotId()); + assertThat(ref.minSnapshotsToKeep()).isNull(); + assertThat(ref.maxSnapshotAgeMs()).isNull(); + assertThat(ref.maxRefAgeMs().longValue()).isEqualTo(TimeUnit.DAYS.toMillis(maxRefAge)); + + assertThatThrownBy(() -> sql("ALTER TABLE %s CREATE BRANCH %s RETAIN", tableName, branchName)) + .isInstanceOf(IcebergParseException.class) + .hasMessageContaining("mismatched input"); + + assertThatThrownBy( + () -> + sql("ALTER TABLE %s CREATE BRANCH %s RETAIN %s DAYS", tableName, branchName, "abc")) + .isInstanceOf(IcebergParseException.class) + .hasMessageContaining("mismatched input"); + + assertThatThrownBy( + () -> + sql( + "ALTER TABLE %s CREATE BRANCH %s RETAIN %d SECONDS", + tableName, branchName, maxRefAge)) + .isInstanceOf(IcebergParseException.class) + .hasMessageContaining("mismatched input 'SECONDS' expecting {'DAYS', 'HOURS', 'MINUTES'}"); + } + + @TestTemplate + public void testDropBranch() throws NoSuchTableException { + insertRows(); + + Table table = validationCatalog.loadTable(tableIdent); + String branchName = "b1"; + table.manageSnapshots().createBranch(branchName, table.currentSnapshot().snapshotId()).commit(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref.snapshotId()).isEqualTo(table.currentSnapshot().snapshotId()); + + sql("ALTER TABLE %s DROP BRANCH %s", tableName, branchName); + table.refresh(); + + ref = table.refs().get(branchName); + assertThat(ref).isNull(); + } + + @TestTemplate + public void testDropBranchDoesNotExist() { + assertThatThrownBy(() -> sql("ALTER TABLE %s DROP BRANCH %s", tableName, "nonExistingBranch")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Branch does not exist: nonExistingBranch"); + } + + @TestTemplate + public void testDropBranchFailsForTag() throws NoSuchTableException { + String tagName = "b1"; + Table table = insertRows(); + table.manageSnapshots().createTag(tagName, table.currentSnapshot().snapshotId()).commit(); + + assertThatThrownBy(() -> sql("ALTER TABLE %s DROP BRANCH %s", tableName, tagName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Ref b1 is a tag not a branch"); + } + + @TestTemplate + public void testDropBranchNonConformingName() { + assertThatThrownBy(() -> sql("ALTER TABLE %s DROP BRANCH %s", tableName, "123")) + .isInstanceOf(IcebergParseException.class) + .hasMessageContaining("mismatched input '123'"); + } + + @TestTemplate + public void testDropMainBranchFails() { + assertThatThrownBy(() -> sql("ALTER TABLE %s DROP BRANCH main", tableName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot remove main branch"); + } + + @TestTemplate + public void testDropBranchIfExists() { + String branchName = "nonExistingBranch"; + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.refs().get(branchName)).isNull(); + + sql("ALTER TABLE %s DROP BRANCH IF EXISTS %s", tableName, branchName); + table.refresh(); + + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref).isNull(); + } + + private Table insertRows() throws NoSuchTableException { + List records = + ImmutableList.of(new SimpleRecord(1, 
"a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + return validationCatalog.loadTable(tableIdent); + } + + @TestTemplate + public void createOrReplace() throws NoSuchTableException { + Table table = insertRows(); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + insertRows(); + long second = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createBranch(branchName, second).commit(); + + sql( + "ALTER TABLE %s CREATE OR REPLACE BRANCH %s AS OF VERSION %d", + tableName, branchName, first); + table.refresh(); + assertThat(table.refs().get(branchName).snapshotId()).isEqualTo(second); + } + + @TestTemplate + public void testCreateOrReplaceBranchOnEmptyTable() { + String branchName = "b1"; + sql("ALTER TABLE %s CREATE OR REPLACE BRANCH %s", tableName, "b1"); + Table table = validationCatalog.loadTable(tableIdent); + + SnapshotRef mainRef = table.refs().get(SnapshotRef.MAIN_BRANCH); + assertThat(mainRef).isNull(); + + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref).isNotNull(); + assertThat(ref.minSnapshotsToKeep()).isNull(); + assertThat(ref.maxSnapshotAgeMs()).isNull(); + assertThat(ref.maxRefAgeMs()).isNull(); + + Snapshot snapshot = table.snapshot(ref.snapshotId()); + assertThat(snapshot.parentId()).isNull(); + assertThat(snapshot.addedDataFiles(table.io())).isEmpty(); + assertThat(snapshot.removedDataFiles(table.io())).isEmpty(); + assertThat(snapshot.addedDeleteFiles(table.io())).isEmpty(); + assertThat(snapshot.removedDeleteFiles(table.io())).isEmpty(); + } + + @TestTemplate + public void createOrReplaceWithNonExistingBranch() throws NoSuchTableException { + Table table = insertRows(); + String branchName = "b1"; + insertRows(); + long snapshotId = table.currentSnapshot().snapshotId(); + + sql( + "ALTER TABLE %s CREATE OR REPLACE BRANCH %s AS OF VERSION %d", + tableName, branchName, snapshotId); + table.refresh(); + assertThat(table.refs().get(branchName).snapshotId()).isEqualTo(snapshotId); + } + + @TestTemplate + public void replaceBranch() throws NoSuchTableException { + Table table = insertRows(); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + long expectedMaxRefAgeMs = 1000; + table + .manageSnapshots() + .createBranch(branchName, first) + .setMaxRefAgeMs(branchName, expectedMaxRefAgeMs) + .commit(); + + insertRows(); + long second = table.currentSnapshot().snapshotId(); + + sql("ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d", tableName, branchName, second); + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref.snapshotId()).isEqualTo(second); + assertThat(ref.maxRefAgeMs()).isEqualTo(expectedMaxRefAgeMs); + } + + @TestTemplate + public void replaceBranchDoesNotExist() throws NoSuchTableException { + Table table = insertRows(); + + assertThatThrownBy( + () -> + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d", + tableName, "someBranch", table.currentSnapshot().snapshotId())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Branch does not exist: someBranch"); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java new file mode 100644 index 000000000000..ade19de36fe9 --- /dev/null +++ 
b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static scala.collection.JavaConverters.seqAsJavaList; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.catalyst.parser.ParserInterface; +import org.apache.spark.sql.catalyst.parser.extensions.IcebergParseException; +import org.apache.spark.sql.catalyst.plans.logical.CallArgument; +import org.apache.spark.sql.catalyst.plans.logical.CallStatement; +import org.apache.spark.sql.catalyst.plans.logical.NamedArgument; +import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +public class TestCallStatementParser { + + private static SparkSession spark = null; + private static ParserInterface parser = null; + + @BeforeAll + public static void startSpark() { + TestCallStatementParser.spark = + SparkSession.builder() + .master("local[2]") + .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName()) + .config("spark.extra.prop", "value") + .getOrCreate(); + TestCallStatementParser.parser = spark.sessionState().sqlParser(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestCallStatementParser.spark; + TestCallStatementParser.spark = null; + TestCallStatementParser.parser = null; + currentSpark.stop(); + } + + @Test + public void testDelegateUnsupportedProcedure() { + assertThatThrownBy(() -> parser.parsePlan("CALL cat.d.t()")) + .isInstanceOf(ParseException.class) + .satisfies( + exception -> { + ParseException parseException = (ParseException) exception; + assertThat(parseException.getErrorClass()).isEqualTo("PARSE_SYNTAX_ERROR"); + assertThat(parseException.getMessageParameters().get("error")).isEqualTo("'CALL'"); + }); + } + + @Test + public void testCallWithBackticks() throws ParseException { + CallStatement call = + (CallStatement) parser.parsePlan("CALL 
cat.`system`.`rollback_to_snapshot`()"); + assertThat(seqAsJavaList(call.name())).containsExactly("cat", "system", "rollback_to_snapshot"); + + assertThat(seqAsJavaList(call.args())).hasSize(0); + } + + @Test + public void testCallWithPositionalArgs() throws ParseException { + CallStatement call = + (CallStatement) + parser.parsePlan( + "CALL c.system.rollback_to_snapshot(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)"); + assertThat(seqAsJavaList(call.name())).containsExactly("c", "system", "rollback_to_snapshot"); + + assertThat(seqAsJavaList(call.args())).hasSize(7); + + checkArg(call, 0, 1, DataTypes.IntegerType); + checkArg(call, 1, "2", DataTypes.StringType); + checkArg(call, 2, 3L, DataTypes.LongType); + checkArg(call, 3, true, DataTypes.BooleanType); + checkArg(call, 4, 1.0D, DataTypes.DoubleType); + checkArg(call, 5, 9.0e1, DataTypes.DoubleType); + checkArg(call, 6, new BigDecimal("900e-1"), DataTypes.createDecimalType(3, 1)); + } + + @Test + public void testCallWithNamedArgs() throws ParseException { + CallStatement call = + (CallStatement) + parser.parsePlan( + "CALL cat.system.rollback_to_snapshot(c1 => 1, c2 => '2', c3 => true)"); + assertThat(seqAsJavaList(call.name())).containsExactly("cat", "system", "rollback_to_snapshot"); + + assertThat(seqAsJavaList(call.args())).hasSize(3); + + checkArg(call, 0, "c1", 1, DataTypes.IntegerType); + checkArg(call, 1, "c2", "2", DataTypes.StringType); + checkArg(call, 2, "c3", true, DataTypes.BooleanType); + } + + @Test + public void testCallWithMixedArgs() throws ParseException { + CallStatement call = + (CallStatement) parser.parsePlan("CALL cat.system.rollback_to_snapshot(c1 => 1, '2')"); + assertThat(seqAsJavaList(call.name())).containsExactly("cat", "system", "rollback_to_snapshot"); + + assertThat(seqAsJavaList(call.args())).hasSize(2); + + checkArg(call, 0, "c1", 1, DataTypes.IntegerType); + checkArg(call, 1, "2", DataTypes.StringType); + } + + @Test + public void testCallWithTimestampArg() throws ParseException { + CallStatement call = + (CallStatement) + parser.parsePlan( + "CALL cat.system.rollback_to_snapshot(TIMESTAMP '2017-02-03T10:37:30.00Z')"); + assertThat(seqAsJavaList(call.name())).containsExactly("cat", "system", "rollback_to_snapshot"); + + assertThat(seqAsJavaList(call.args())).hasSize(1); + + checkArg( + call, 0, Timestamp.from(Instant.parse("2017-02-03T10:37:30.00Z")), DataTypes.TimestampType); + } + + @Test + public void testCallWithVarSubstitution() throws ParseException { + CallStatement call = + (CallStatement) + parser.parsePlan("CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')"); + assertThat(seqAsJavaList(call.name())).containsExactly("cat", "system", "rollback_to_snapshot"); + + assertThat(seqAsJavaList(call.args())).hasSize(1); + + checkArg(call, 0, "value", DataTypes.StringType); + } + + @Test + public void testCallParseError() { + assertThatThrownBy(() -> parser.parsePlan("CALL cat.system.rollback_to_snapshot kebab")) + .isInstanceOf(IcebergParseException.class) + .hasMessageContaining("missing '(' at 'kebab'"); + } + + @Test + public void testCallStripsComments() throws ParseException { + List callStatementsWithComments = + Lists.newArrayList( + "/* bracketed comment */ CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')", + "/**/ CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')", + "-- single line comment \n CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')", + "-- multiple \n-- single line \n-- comments \n CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')", + 
"/* select * from multiline_comment \n where x like '%sql%'; */ CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')", + "/* {\"app\": \"dbt\", \"dbt_version\": \"1.0.1\", \"profile_name\": \"profile1\", \"target_name\": \"dev\", " + + "\"node_id\": \"model.profile1.stg_users\"} \n*/ CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')", + "/* Some multi-line comment \n" + + "*/ CALL /* inline comment */ cat.system.rollback_to_snapshot('${spark.extra.prop}') -- ending comment", + "CALL -- a line ending comment\n" + + "cat.system.rollback_to_snapshot('${spark.extra.prop}')"); + for (String sqlText : callStatementsWithComments) { + CallStatement call = (CallStatement) parser.parsePlan(sqlText); + assertThat(seqAsJavaList(call.name())) + .containsExactly("cat", "system", "rollback_to_snapshot"); + + assertThat(seqAsJavaList(call.args())).hasSize(1); + + checkArg(call, 0, "value", DataTypes.StringType); + } + } + + private void checkArg( + CallStatement call, int index, Object expectedValue, DataType expectedType) { + checkArg(call, index, null, expectedValue, expectedType); + } + + private void checkArg( + CallStatement call, + int index, + String expectedName, + Object expectedValue, + DataType expectedType) { + + if (expectedName != null) { + NamedArgument arg = checkCast(call.args().apply(index), NamedArgument.class); + assertThat(arg.name()).isEqualTo(expectedName); + } else { + CallArgument arg = call.args().apply(index); + checkCast(arg, PositionalArgument.class); + } + + Expression expectedExpr = toSparkLiteral(expectedValue, expectedType); + Expression actualExpr = call.args().apply(index).expr(); + assertThat(actualExpr.dataType()).as("Arg types must match").isEqualTo(expectedExpr.dataType()); + assertThat(actualExpr).as("Arg must match").isEqualTo(expectedExpr); + } + + private Literal toSparkLiteral(Object value, DataType dataType) { + return Literal$.MODULE$.create(value, dataType); + } + + private T checkCast(Object value, Class expectedClass) { + assertThat(value).isInstanceOf(expectedClass); + return expectedClass.cast(value); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestChangelogTable.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestChangelogTable.java new file mode 100644 index 000000000000..a7ed065cae2c --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestChangelogTable.java @@ -0,0 +1,411 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; +import static org.apache.iceberg.TableProperties.MANIFEST_MERGE_ENABLED; +import static org.apache.iceberg.TableProperties.MANIFEST_MIN_MERGE_COUNT; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import org.apache.iceberg.DataOperations; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.source.SparkChangelogTable; +import org.apache.spark.sql.DataFrameReader; +import org.apache.spark.sql.Row; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestChangelogTable extends ExtensionsTestBase { + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}, formatVersion = {3}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties(), + 1 + }, + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties(), + 2 + } + }; + } + + @Parameter(index = 3) + private int formatVersion; + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testDataFilters() { + createTableWithDefaultRows(); + + sql("INSERT INTO %s VALUES (3, 'c')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snap3 = table.currentSnapshot(); + + sql("DELETE FROM %s WHERE id = 3", tableName); + + table.refresh(); + + Snapshot snap4 = table.currentSnapshot(); + + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(3, "c", "INSERT", 2, snap3.snapshotId()), + row(3, "c", "DELETE", 3, snap4.snapshotId())), + sql("SELECT * FROM %s.changes WHERE id = 3 ORDER BY _change_ordinal, id", tableName)); + } + + @TestTemplate + public void testOverwrites() { + createTableWithDefaultRows(); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snap2 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName); + + table.refresh(); + + Snapshot snap3 = table.currentSnapshot(); + + assertEquals( + "Rows should match", + ImmutableList.of( + row(2, "b", "DELETE", 0, snap3.snapshotId()), + row(-2, "b", "INSERT", 0, snap3.snapshotId())), + changelogRecords(snap2, snap3)); + } + + @TestTemplate + public void testQueryWithTimeRange() { + createTable(); + + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + long rightAfterSnap1 = waitUntilAfter(snap1.timestampMillis()); + + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + long rightAfterSnap2 = waitUntilAfter(snap2.timestampMillis()); + + sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName); + table.refresh(); + Snapshot snap3 = 
table.currentSnapshot(); + long rightAfterSnap3 = waitUntilAfter(snap3.timestampMillis()); + + assertEquals( + "Should have expected changed rows only from snapshot 3", + ImmutableList.of( + row(2, "b", "DELETE", 0, snap3.snapshotId()), + row(-2, "b", "INSERT", 0, snap3.snapshotId())), + changelogRecords(rightAfterSnap2, snap3.timestampMillis())); + + assertEquals( + "Should have expected changed rows only from snapshot 3", + ImmutableList.of( + row(2, "b", "DELETE", 0, snap3.snapshotId()), + row(-2, "b", "INSERT", 0, snap3.snapshotId())), + changelogRecords(snap2.timestampMillis(), snap3.timestampMillis())); + + assertEquals( + "Should have expected changed rows from snapshot 2 and 3", + ImmutableList.of( + row(2, "b", "INSERT", 0, snap2.snapshotId()), + row(2, "b", "DELETE", 1, snap3.snapshotId()), + row(-2, "b", "INSERT", 1, snap3.snapshotId())), + changelogRecords(rightAfterSnap1, snap3.timestampMillis())); + + assertEquals( + "Should have expected changed rows up to the current snapshot", + ImmutableList.of( + row(2, "b", "INSERT", 0, snap2.snapshotId()), + row(2, "b", "DELETE", 1, snap3.snapshotId()), + row(-2, "b", "INSERT", 1, snap3.snapshotId())), + changelogRecords(rightAfterSnap1, null)); + + assertEquals( + "Should have empty changed rows if end time is before the first snapshot", + ImmutableList.of(), + changelogRecords(null, snap1.timestampMillis() - 1)); + + assertEquals( + "Should have empty changed rows if start time is after the current snapshot", + ImmutableList.of(), + changelogRecords(rightAfterSnap3, null)); + + assertEquals( + "Should have empty changed rows if end time is before the first snapshot", + ImmutableList.of(), + changelogRecords(null, snap1.timestampMillis() - 1)); + + assertEquals( + "Should have empty changed rows if there are no snapshots between start time and end time", + ImmutableList.of(), + changelogRecords(rightAfterSnap2, snap3.timestampMillis() - 1)); + } + + @TestTemplate + public void testTimeRangeValidation() { + createTableWithDefaultRows(); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snap2 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName); + table.refresh(); + Snapshot snap3 = table.currentSnapshot(); + long rightAfterSnap3 = waitUntilAfter(snap3.timestampMillis()); + assertThatThrownBy(() -> changelogRecords(snap3.timestampMillis(), snap2.timestampMillis())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot set start-timestamp to be greater than end-timestamp for changelogs"); + } + + @TestTemplate + public void testMetadataDeletes() { + createTableWithDefaultRows(); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snap2 = table.currentSnapshot(); + + sql("DELETE FROM %s WHERE data = 'a'", tableName); + + table.refresh(); + + Snapshot snap3 = table.currentSnapshot(); + assertThat(snap3.operation()).as("Operation must match").isEqualTo(DataOperations.DELETE); + + assertEquals( + "Rows should match", + ImmutableList.of(row(1, "a", "DELETE", 0, snap3.snapshotId())), + changelogRecords(snap2, snap3)); + } + + @TestTemplate + public void testExistingEntriesInNewDataManifestsAreIgnored() { + sql( + "CREATE TABLE %s (id INT, data STRING) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES ( " + + " '%s' = '%d', " + + " '%s' = '1', " + + " '%s' = 'true' " + + ")", + tableName, FORMAT_VERSION, formatVersion, MANIFEST_MIN_MERGE_COUNT, MANIFEST_MERGE_ENABLED); + + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + + Table 
table = validationCatalog.loadTable(tableIdent); + + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + + table.refresh(); + + Snapshot snap2 = table.currentSnapshot(); + assertThat(snap2.dataManifests(table.io())).as("Manifest number must match").hasSize(1); + + assertEquals( + "Rows should match", + ImmutableList.of(row(2, "b", "INSERT", 0, snap2.snapshotId())), + changelogRecords(snap1, snap2)); + } + + @TestTemplate + public void testManifestRewritesAreIgnored() { + createTableWithDefaultRows(); + + sql("CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Num snapshots must match").hasSize(3); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "INSERT"), row(2, "INSERT")), + sql("SELECT id, _change_type FROM %s.changes ORDER BY id", tableName)); + } + + @TestTemplate + public void testMetadataColumns() { + createTableWithDefaultRows(); + List rows = + sql( + "SELECT id, _file, _pos, _deleted, _spec_id, _partition FROM %s.changes ORDER BY id", + tableName); + + String file1 = rows.get(0)[1].toString(); + assertThat(file1).startsWith("file:/"); + String file2 = rows.get(1)[1].toString(); + + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, file1, 0L, false, 0, row("a")), row(2, file2, 0L, false, 0, row("b"))), + rows); + } + + @TestTemplate + public void testQueryWithRollback() { + createTable(); + + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + long rightAfterSnap1 = waitUntilAfter(snap1.timestampMillis()); + + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + long rightAfterSnap2 = waitUntilAfter(snap2.timestampMillis()); + + sql( + "CALL %s.system.rollback_to_snapshot('%s', %d)", + catalogName, tableIdent, snap1.snapshotId()); + table.refresh(); + assertThat(table.currentSnapshot()).isEqualTo(snap1); + + sql("INSERT OVERWRITE %s VALUES (-2, 'a')", tableName); + table.refresh(); + Snapshot snap3 = table.currentSnapshot(); + long rightAfterSnap3 = waitUntilAfter(snap3.timestampMillis()); + + assertEquals( + "Should have expected changed rows up to snapshot 3", + ImmutableList.of( + row(1, "a", "INSERT", 0, snap1.snapshotId()), + row(1, "a", "DELETE", 1, snap3.snapshotId()), + row(-2, "a", "INSERT", 1, snap3.snapshotId())), + changelogRecords(null, rightAfterSnap3)); + + assertEquals( + "Should have expected changed rows up to snapshot 2", + ImmutableList.of(row(1, "a", "INSERT", 0, snap1.snapshotId())), + changelogRecords(null, rightAfterSnap2)); + + assertEquals( + "Should have expected changed rows from snapshot 3 only since snapshot 2 is on a different branch.", + ImmutableList.of( + row(1, "a", "DELETE", 0, snap3.snapshotId()), + row(-2, "a", "INSERT", 0, snap3.snapshotId())), + changelogRecords(rightAfterSnap1, snap3.timestampMillis())); + + assertEquals( + "Should have expected changed rows from snapshot 3", + ImmutableList.of( + row(1, "a", "DELETE", 0, snap3.snapshotId()), + row(-2, "a", "INSERT", 0, snap3.snapshotId())), + changelogRecords(rightAfterSnap2, null)); + + sql( + "CALL %s.system.set_current_snapshot('%s', %d)", + catalogName, tableIdent, snap2.snapshotId()); + table.refresh(); + assertThat(table.currentSnapshot()).isEqualTo(snap2); + assertEquals( + "Should have expected changed rows from snapshot 2 
only since snapshot 3 is on a different branch.", + ImmutableList.of(row(2, "b", "INSERT", 0, snap2.snapshotId())), + changelogRecords(rightAfterSnap1, null)); + } + + private void createTableWithDefaultRows() { + createTable(); + insertDefaultRows(); + } + + private void createTable() { + sql( + "CREATE TABLE %s (id INT, data STRING) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES ( " + + " '%s' = '%d' " + + ")", + tableName, FORMAT_VERSION, formatVersion); + } + + private void insertDefaultRows() { + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + } + + private List changelogRecords(Snapshot startSnapshot, Snapshot endSnapshot) { + DataFrameReader reader = spark.read(); + + if (startSnapshot != null) { + reader = reader.option(SparkReadOptions.START_SNAPSHOT_ID, startSnapshot.snapshotId()); + } + + if (endSnapshot != null) { + reader = reader.option(SparkReadOptions.END_SNAPSHOT_ID, endSnapshot.snapshotId()); + } + + return rowsToJava(collect(reader)); + } + + private List changelogRecords(Long startTimestamp, Long endTimeStamp) { + DataFrameReader reader = spark.read(); + + if (startTimestamp != null) { + reader = reader.option(SparkReadOptions.START_TIMESTAMP, startTimestamp); + } + + if (endTimeStamp != null) { + reader = reader.option(SparkReadOptions.END_TIMESTAMP, endTimeStamp); + } + + return rowsToJava(collect(reader)); + } + + private List collect(DataFrameReader reader) { + return reader + .table(tableName + "." + SparkChangelogTable.TABLE_NAME) + .orderBy("_change_ordinal", "_commit_snapshot_id", "_change_type", "id") + .collectAsList(); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCherrypickSnapshotProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCherrypickSnapshotProcedure.java new file mode 100644 index 000000000000..08b0754df43d --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCherrypickSnapshotProcedure.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import java.util.List; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestCherrypickSnapshotProcedure extends ExtensionsTestBase { + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testCherrypickSnapshotUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + List output = + sql( + "CALL %s.system.cherrypick_snapshot('%s', %dL)", + catalogName, tableIdent, wapSnapshot.snapshotId()); + + table.refresh(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(wapSnapshot.snapshotId(), currentSnapshot.snapshotId())), + output); + + assertEquals( + "Cherrypick must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + } + + @TestTemplate + public void testCherrypickSnapshotUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + List output = + sql( + "CALL %s.system.cherrypick_snapshot(snapshot_id => %dL, table => '%s')", + catalogName, wapSnapshot.snapshotId(), tableIdent); + + table.refresh(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(wapSnapshot.snapshotId(), currentSnapshot.snapshotId())), + output); + + assertEquals( + "Cherrypick must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + } + + @TestTemplate + public void testCherrypickSnapshotRefreshesRelationCache() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + Dataset query = spark.sql("SELECT * FROM " + tableName + 
" WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals("View should not produce rows", ImmutableList.of(), sql("SELECT * FROM tmp")); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + sql( + "CALL %s.system.cherrypick_snapshot('%s', %dL)", + catalogName, tableIdent, wapSnapshot.snapshotId()); + + assertEquals( + "Cherrypick snapshot should be visible", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM tmp")); + + sql("UNCACHE TABLE tmp"); + } + + @TestTemplate + public void testCherrypickInvalidSnapshot() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + assertThatThrownBy( + () -> sql("CALL %s.system.cherrypick_snapshot('%s', -1L)", catalogName, tableIdent)) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot cherry-pick unknown snapshot ID: -1"); + } + + @TestTemplate + public void testInvalidCherrypickSnapshotCases() { + assertThatThrownBy( + () -> sql("CALL %s.system.cherrypick_snapshot('n', table => 't', 1L)", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Named and positional arguments cannot be mixed"); + + assertThatThrownBy(() -> sql("CALL %s.custom.cherrypick_snapshot('n', 't', 1L)", catalogName)) + .isInstanceOf(ParseException.class) + .satisfies( + exception -> { + ParseException parseException = (ParseException) exception; + assertThat(parseException.getErrorClass()).isEqualTo("PARSE_SYNTAX_ERROR"); + assertThat(parseException.getMessageParameters().get("error")).isEqualTo("'CALL'"); + }); + + assertThatThrownBy(() -> sql("CALL %s.system.cherrypick_snapshot('t')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [snapshot_id]"); + + assertThatThrownBy(() -> sql("CALL %s.system.cherrypick_snapshot('', 1L)", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot handle an empty identifier for argument table"); + + assertThatThrownBy(() -> sql("CALL %s.system.cherrypick_snapshot('t', 2.2)", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith("Wrong arg type for snapshot_id: cannot cast"); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestConflictValidation.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestConflictValidation.java new file mode 100644 index 000000000000..b5ba7eec1b01 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestConflictValidation.java @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.functions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestConflictValidation extends ExtensionsTestBase { + + @BeforeEach + public void createTables() { + sql( + "CREATE TABLE %s (id int, data string) USING iceberg " + + "PARTITIONED BY (id)" + + "TBLPROPERTIES" + + "('format-version'='2'," + + "'write.delete.mode'='merge-on-read')", + tableName); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testOverwriteFilterSerializableIsolation() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validating from previous snapshot finds conflicts + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + assertThatThrownBy( + () -> + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option( + SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1))) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith( + "Found conflicting files that can contain records matching ref(name=\"id\") == 1:"); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @TestTemplate + public void testOverwriteFilterSerializableIsolation2() throws Exception { + List records = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(1, "b")); + spark.createDataFrame(records, SimpleRecord.class).coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + // This should add a delete file + sql("DELETE FROM %s WHERE id='1' 
and data='b'", tableName); + table.refresh(); + + // Validating from previous snapshot finds conflicts + List conflictingRecords = Lists.newArrayList(new SimpleRecord(1, "a")); + Dataset conflictingDf = spark.createDataFrame(conflictingRecords, SimpleRecord.class); + assertThatThrownBy( + () -> + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwrite(functions.col("id").equalTo(1))) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith( + "Found new conflicting delete files that can apply to records matching ref(name=\"id\") == 1:"); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @TestTemplate + public void testOverwriteFilterSerializableIsolation3() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + // This should delete a data file + sql("DELETE FROM %s WHERE id='1'", tableName); + table.refresh(); + + // Validating from previous snapshot finds conflicts + List conflictingRecords = Lists.newArrayList(new SimpleRecord(1, "a")); + Dataset conflictingDf = spark.createDataFrame(conflictingRecords, SimpleRecord.class); + assertThatThrownBy( + () -> + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option( + SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1))) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith( + "Found conflicting deleted files that can contain records matching ref(name=\"id\") == 1:"); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @TestTemplate + public void testOverwriteFilterNoSnapshotIdValidation() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validating from no snapshot id defaults to beginning snapshot id and finds conflicts + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + + assertThatThrownBy( + () -> + conflictingDf + .writeTo(tableName) + .option( + SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1))) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith( + "Found conflicting files that can contain records matching ref(name=\"id\") == 1:"); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, 
String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @TestTemplate + public void testOverwriteFilterSnapshotIsolation() throws Exception { + List records = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(1, "b")); + spark.createDataFrame(records, SimpleRecord.class).coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + // This should add a delete file + sql("DELETE FROM %s WHERE id='1' and data='b'", tableName); + table.refresh(); + + // Validating from previous snapshot finds conflicts + List conflictingRecords = Lists.newArrayList(new SimpleRecord(1, "a")); + Dataset conflictingDf = spark.createDataFrame(conflictingRecords, SimpleRecord.class); + assertThatThrownBy( + () -> + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwrite(functions.col("id").equalTo(1))) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith( + "Found new conflicting delete files that can apply to records matching ref(name=\"id\") == 1:"); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @TestTemplate + public void testOverwriteFilterSnapshotIsolation2() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validation should not fail due to conflicting data file in snapshot isolation mode + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @TestTemplate + public void testOverwritePartitionSerializableIsolation() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + final long snapshotId = table.currentSnapshot().snapshotId(); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validating from previous snapshot finds conflicts + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + assertThatThrownBy( + () -> + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option( + SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwritePartitions()) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith( + "Found conflicting files that can contain records matching partitions [id=1]"); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = 
table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwritePartitions(); + } + + @TestTemplate + public void testOverwritePartitionSnapshotIsolation() throws Exception { + List records = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(1, "b")); + spark.createDataFrame(records, SimpleRecord.class).coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + final long snapshotId = table.currentSnapshot().snapshotId(); + + // This should generate a delete file + sql("DELETE FROM %s WHERE data='a'", tableName); + + // Validating from previous snapshot finds conflicts + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + assertThatThrownBy( + () -> + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions()) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith( + "Found new conflicting delete files that can apply to records matching [id=1]"); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions(); + } + + @TestTemplate + public void testOverwritePartitionSnapshotIsolation2() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + final long snapshotId = table.currentSnapshot().snapshotId(); + + // This should delete a data file + sql("DELETE FROM %s WHERE id='1'", tableName); + + // Validating from previous snapshot finds conflicts + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).coalesce(1).writeTo(tableName).append(); + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + + assertThatThrownBy( + () -> + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions()) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith( + "Found conflicting deleted files that can apply to records matching [id=1]"); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions(); + } + + @TestTemplate + public void testOverwritePartitionSnapshotIsolation3() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + final long snapshotId = table.currentSnapshot().snapshotId(); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validation should not find conflicting data file in snapshot isolation mode + Dataset conflictingDf = 
spark.createDataFrame(records, SimpleRecord.class); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions(); + } + + @TestTemplate + public void testOverwritePartitionNoSnapshotIdValidation() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validating from null snapshot is equivalent to validating from beginning + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + assertThatThrownBy( + () -> + conflictingDf + .writeTo(tableName) + .option( + SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwritePartitions()) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith( + "Found conflicting files that can contain records matching partitions [id=1]"); + + // Validating from latest snapshot should succeed + table.refresh(); + long snapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwritePartitions(); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java new file mode 100644 index 000000000000..e5d44d97de1e --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestCopyOnWriteDelete extends TestDelete { + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.DELETE_MODE, RowLevelOperationMode.COPY_ON_WRITE.modeName()); + } + + @TestTemplate + public synchronized void testDeleteWithConcurrentTableRefresh() throws Exception { + // this test can only be run with Hive tables as it requires a reliable lock + // also, the table cache must be enabled so that the same table instance can be reused + assumeThat(catalogName).isEqualToIgnoringCase("testhive"); + + createAndInitUnpartitionedTable(); + createOrReplaceView("deleted_id", Collections.singletonList(1), Encoders.INT()); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, DELETE_ISOLATION_LEVEL, "snapshot"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // delete thread + Future deleteFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id)", commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; 
numOperations < Integer.MAX_VALUE; numOperations++) { + while (shouldAppend.get() && barrier.get() < numOperations * 2) { + sleep(10); + } + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + sleep(10); + } + + barrier.incrementAndGet(); + } + }); + + try { + assertThatThrownBy(deleteFuture::get) + .isInstanceOf(ExecutionException.class) + .cause() + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("the table has been concurrently modified"); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + assertThat(executorService.awaitTermination(2, TimeUnit.MINUTES)).as("Timeout").isTrue(); + } + + @TestTemplate + public void testRuntimeFilteringWithPreservedDataGrouping() throws NoSuchTableException { + createAndInitPartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(3, "hr")); + createBranchIfNeeded(); + append(new Employee(1, "hardware"), new Employee(2, "hardware")); + + Map sqlConf = + ImmutableMap.of( + SQLConf.V2_BUCKETING_ENABLED().key(), + "true", + SparkSQLProperties.PRESERVE_DATA_GROUPING, + "true"); + + withSQLConf(sqlConf, () -> sql("DELETE FROM %s WHERE id = 2", commitTarget())); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 3 snapshots").hasSize(3); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java new file mode 100644 index 000000000000..1fb1238de635 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.TestTemplate; + +public class TestCopyOnWriteMerge extends TestMerge { + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.MERGE_MODE, RowLevelOperationMode.COPY_ON_WRITE.modeName()); + } + + @TestTemplate + public synchronized void testMergeWithConcurrentTableRefresh() throws Exception { + // this test can only be run with Hive tables as it requires a reliable lock + // also, the table cache must be enabled so that the same table instance can be reused + assumeThat(catalogName).isEqualToIgnoringCase("testhive"); + + createAndInitTable("id INT, dep STRING"); + createOrReplaceView("source", Collections.singletonList(1), Encoders.INT()); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, MERGE_ISOLATION_LEVEL, "snapshot"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // merge thread + Future mergeFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.value " + + "WHEN MATCHED THEN " + + " UPDATE SET dep = 'x'", + tableName); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + GenericRecord record = GenericRecord.create(table.schema()); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (shouldAppend.get() && barrier.get() < numOperations * 2) { + sleep(10); + } + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) 
{ + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + table.newFastAppend().appendFile(dataFile).commit(); + sleep(10); + } + + barrier.incrementAndGet(); + } + }); + + try { + assertThatThrownBy(mergeFuture::get) + .isInstanceOf(ExecutionException.class) + .cause() + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("the table has been concurrently modified"); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + assertThat(executorService.awaitTermination(2, TimeUnit.MINUTES)).as("Timeout").isTrue(); + } + + @TestTemplate + public void testRuntimeFilteringWithReportedPartitioning() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + append( + commitTarget(), + "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + createOrReplaceView("source", Collections.singletonList(2), Encoders.INT()); + + Map sqlConf = + ImmutableMap.of( + SQLConf.V2_BUCKETING_ENABLED().key(), + "true", + SparkSQLProperties.PRESERVE_DATA_GROUPING, + "true"); + + withSQLConf( + sqlConf, + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.value " + + "WHEN MATCHED THEN " + + " UPDATE SET id = -1", + commitTarget())); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 3 snapshots").hasSize(3); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java new file mode 100644 index 000000000000..5bc7b22f9a09 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.UPDATE_ISOLATION_LEVEL; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.TestTemplate; + +public class TestCopyOnWriteUpdate extends TestUpdate { + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.UPDATE_MODE, RowLevelOperationMode.COPY_ON_WRITE.modeName()); + } + + @TestTemplate + public synchronized void testUpdateWithConcurrentTableRefresh() throws Exception { + // this test can only be run with Hive tables as it requires a reliable lock + // also, the table cache must be enabled so that the same table instance can be reused + assumeThat(catalogName).isEqualToIgnoringCase("testhive"); + + createAndInitTable("id INT, dep STRING"); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, UPDATE_ISOLATION_LEVEL, "snapshot"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // update thread + Future updateFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + sql("UPDATE %s SET id = -1 WHERE id = 1", commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (shouldAppend.get() && barrier.get() < numOperations * 2) { + sleep(10); + } + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if 
(branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + sleep(10); + } + + barrier.incrementAndGet(); + } + }); + + try { + assertThatThrownBy(updateFuture::get) + .isInstanceOf(ExecutionException.class) + .cause() + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("the table has been concurrently modified"); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + assertThat(executorService.awaitTermination(2, TimeUnit.MINUTES)).as("Timeout").isTrue(); + } + + @TestTemplate + public void testRuntimeFilteringWithReportedPartitioning() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + append( + commitTarget(), + "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + Map sqlConf = + ImmutableMap.of( + SQLConf.V2_BUCKETING_ENABLED().key(), + "true", + SparkSQLProperties.PRESERVE_DATA_GROUPING, + "true"); + + withSQLConf(sqlConf, () -> sql("UPDATE %s SET id = -1 WHERE id = 2", commitTarget())); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 3 snapshots").hasSize(3); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateChangelogViewProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateChangelogViewProcedure.java new file mode 100644 index 000000000000..3fd760c67c4a --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateChangelogViewProcedure.java @@ -0,0 +1,453 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import org.apache.iceberg.ChangelogOperation; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkReadOptions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestCreateChangelogViewProcedure extends ExtensionsTestBase { + private static final String DELETE = ChangelogOperation.DELETE.name(); + private static final String INSERT = ChangelogOperation.INSERT.name(); + private static final String UPDATE_BEFORE = ChangelogOperation.UPDATE_BEFORE.name(); + private static final String UPDATE_AFTER = ChangelogOperation.UPDATE_AFTER.name(); + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + public void createTableWithTwoColumns() { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD data", tableName); + } + + private void createTableWithThreeColumns() { + sql("CREATE TABLE %s (id INT, data STRING, age INT) USING iceberg", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD id", tableName); + } + + private void createTableWithIdentifierField() { + sql("CREATE TABLE %s (id INT NOT NULL, data STRING) USING iceberg", tableName); + sql("ALTER TABLE %s SET IDENTIFIER FIELDS id", tableName); + } + + @TestTemplate + public void testCustomizedViewName() { + createTableWithTwoColumns(); + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName); + + table.refresh(); + + Snapshot snap2 = table.currentSnapshot(); + + sql( + "CALL %s.system.create_changelog_view(" + + "table => '%s'," + + "options => map('%s','%s','%s','%s')," + + "changelog_view => '%s')", + catalogName, + tableName, + SparkReadOptions.START_SNAPSHOT_ID, + snap1.snapshotId(), + SparkReadOptions.END_SNAPSHOT_ID, + snap2.snapshotId(), + "cdc_view"); + + long rowCount = sql("select * from %s", "cdc_view").stream().count(); + assertThat(rowCount).isEqualTo(2); + } + + @TestTemplate + public void testNoSnapshotIdInput() { + createTableWithTwoColumns(); + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap0 = table.currentSnapshot(); + + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + table.refresh(); + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql( + "CALL %s.system.create_changelog_view(" + "table => '%s')", + catalogName, tableName, "cdc_view"); + + String viewName = (String) returns.get(0)[0]; + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", INSERT, 0, snap0.snapshotId()), + row(2, "b", INSERT, 1, snap1.snapshotId()), + row(-2, "b", INSERT, 2, snap2.snapshotId()), + row(2, "b", DELETE, 2, snap2.snapshotId())), + sql("select * from %s order by 
_change_ordinal, id", viewName)); + } + + @TestTemplate + public void testTimestampsBasedQuery() { + createTableWithTwoColumns(); + long beginning = System.currentTimeMillis(); + + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap0 = table.currentSnapshot(); + long afterFirstInsert = waitUntilAfter(snap0.timestampMillis()); + + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + table.refresh(); + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + long afterInsertOverwrite = waitUntilAfter(snap2.timestampMillis()); + List returns = + sql( + "CALL %s.system.create_changelog_view(table => '%s', " + + "options => map('%s', '%s','%s', '%s'))", + catalogName, + tableName, + SparkReadOptions.START_TIMESTAMP, + beginning, + SparkReadOptions.END_TIMESTAMP, + afterInsertOverwrite); + + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", INSERT, 0, snap0.snapshotId()), + row(2, "b", INSERT, 1, snap1.snapshotId()), + row(-2, "b", INSERT, 2, snap2.snapshotId()), + row(2, "b", DELETE, 2, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id", returns.get(0)[0])); + + // query the timestamps starting from the second insert + returns = + sql( + "CALL %s.system.create_changelog_view(table => '%s', " + + "options => map('%s', '%s', '%s', '%s'))", + catalogName, + tableName, + SparkReadOptions.START_TIMESTAMP, + afterFirstInsert, + SparkReadOptions.END_TIMESTAMP, + afterInsertOverwrite); + + assertEquals( + "Rows should match", + ImmutableList.of( + row(2, "b", INSERT, 0, snap1.snapshotId()), + row(-2, "b", INSERT, 1, snap2.snapshotId()), + row(2, "b", DELETE, 1, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id", returns.get(0)[0])); + } + + @TestTemplate + public void testUpdate() { + createTableWithTwoColumns(); + sql("ALTER TABLE %s DROP PARTITION FIELD data", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD id", tableName); + + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (3, 'c'), (2, 'd')", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql( + "CALL %s.system.create_changelog_view(table => '%s', identifier_columns => array('id'))", + catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", INSERT, 0, snap1.snapshotId()), + row(2, "b", INSERT, 0, snap1.snapshotId()), + row(2, "b", UPDATE_BEFORE, 1, snap2.snapshotId()), + row(2, "d", UPDATE_AFTER, 1, snap2.snapshotId()), + row(3, "c", INSERT, 1, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id, data", viewName)); + } + + @TestTemplate + public void testUpdateWithIdentifierField() { + createTableWithIdentifierField(); + + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (3, 'c'), (2, 'd')", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql( + "CALL %s.system.create_changelog_view(table => '%s', compute_updates => true)", + catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + 
assertEquals( + "Rows should match", + ImmutableList.of( + row(2, "b", INSERT, 0, snap1.snapshotId()), + row(2, "b", UPDATE_BEFORE, 1, snap2.snapshotId()), + row(2, "d", UPDATE_AFTER, 1, snap2.snapshotId()), + row(3, "c", INSERT, 1, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id, data", viewName)); + } + + @TestTemplate + public void testUpdateWithFilter() { + createTableWithTwoColumns(); + sql("ALTER TABLE %s DROP PARTITION FIELD data", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD id", tableName); + + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (3, 'c'), (2, 'd')", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql( + "CALL %s.system.create_changelog_view(table => '%s', identifier_columns => array('id'))", + catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", INSERT, 0, snap1.snapshotId()), + row(2, "b", INSERT, 0, snap1.snapshotId()), + row(2, "b", UPDATE_BEFORE, 1, snap2.snapshotId()), + row(2, "d", UPDATE_AFTER, 1, snap2.snapshotId())), + // the predicate on partition columns will filter out the insert of (3, 'c') at the planning + // phase + sql("select * from %s where id != 3 order by _change_ordinal, id, data", viewName)); + } + + @TestTemplate + public void testUpdateWithMultipleIdentifierColumns() { + createTableWithThreeColumns(); + + sql("INSERT INTO %s VALUES (1, 'a', 12), (2, 'b', 11)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (3, 'c', 13), (2, 'd', 11), (2, 'e', 12)", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql( + "CALL %s.system.create_changelog_view(" + + "identifier_columns => array('id','age')," + + "table => '%s')", + catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", 12, INSERT, 0, snap1.snapshotId()), + row(2, "b", 11, INSERT, 0, snap1.snapshotId()), + row(2, "b", 11, UPDATE_BEFORE, 1, snap2.snapshotId()), + row(2, "d", 11, UPDATE_AFTER, 1, snap2.snapshotId()), + row(2, "e", 12, INSERT, 1, snap2.snapshotId()), + row(3, "c", 13, INSERT, 1, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id, data", viewName)); + } + + @TestTemplate + public void testRemoveCarryOvers() { + createTableWithThreeColumns(); + + sql("INSERT INTO %s VALUES (1, 'a', 12), (2, 'b', 11), (2, 'e', 12)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + // carry-over row (2, 'e', 12) + sql("INSERT OVERWRITE %s VALUES (3, 'c', 13), (2, 'd', 11), (2, 'e', 12)", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql( + "CALL %s.system.create_changelog_view(" + + "identifier_columns => array('id','age'), " + + "table => '%s')", + catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + // the carry-over rows (2, 'e', 12, 'DELETE', 1), (2, 'e', 12, 'INSERT', 1) are removed + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", 12, INSERT, 0, snap1.snapshotId()), + row(2, "b", 11, INSERT, 0, snap1.snapshotId()), + row(2, "e", 12, INSERT, 0, 
snap1.snapshotId()), + row(2, "b", 11, UPDATE_BEFORE, 1, snap2.snapshotId()), + row(2, "d", 11, UPDATE_AFTER, 1, snap2.snapshotId()), + row(3, "c", 13, INSERT, 1, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id, data", viewName)); + } + + @TestTemplate + public void testRemoveCarryOversWithoutUpdatedRows() { + createTableWithThreeColumns(); + + sql("INSERT INTO %s VALUES (1, 'a', 12), (2, 'b', 11), (2, 'e', 12)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + // carry-over row (2, 'e', 12) + sql("INSERT OVERWRITE %s VALUES (3, 'c', 13), (2, 'd', 11), (2, 'e', 12)", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql("CALL %s.system.create_changelog_view(table => '%s')", catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + + // the carry-over rows (2, 'e', 12, 'DELETE', 1), (2, 'e', 12, 'INSERT', 1) are removed, even + // though update-row is not computed + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", 12, INSERT, 0, snap1.snapshotId()), + row(2, "b", 11, INSERT, 0, snap1.snapshotId()), + row(2, "e", 12, INSERT, 0, snap1.snapshotId()), + row(2, "b", 11, DELETE, 1, snap2.snapshotId()), + row(2, "d", 11, INSERT, 1, snap2.snapshotId()), + row(3, "c", 13, INSERT, 1, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id, data", viewName)); + } + + @TestTemplate + public void testNetChangesWithRemoveCarryOvers() { + // partitioned by id + createTableWithThreeColumns(); + + // insert rows: (1, 'a', 12) (2, 'b', 11) (2, 'e', 12) + sql("INSERT INTO %s VALUES (1, 'a', 12), (2, 'b', 11), (2, 'e', 12)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + // delete rows: (2, 'b', 11) (2, 'e', 12) + // insert rows: (3, 'c', 13) (2, 'd', 11) (2, 'e', 12) + sql("INSERT OVERWRITE %s VALUES (3, 'c', 13), (2, 'd', 11), (2, 'e', 12)", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + // delete rows: (2, 'd', 11) (2, 'e', 12) (3, 'c', 13) + // insert rows: (3, 'c', 15) (2, 'e', 12) + sql("INSERT OVERWRITE %s VALUES (3, 'c', 15), (2, 'e', 12)", tableName); + table.refresh(); + Snapshot snap3 = table.currentSnapshot(); + + // test with all snapshots + List returns = + sql( + "CALL %s.system.create_changelog_view(table => '%s', net_changes => true)", + catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", 12, INSERT, 0, snap1.snapshotId()), + row(3, "c", 15, INSERT, 2, snap3.snapshotId()), + row(2, "e", 12, INSERT, 2, snap3.snapshotId())), + sql("select * from %s order by _change_ordinal, data", viewName)); + + // test with snap2 and snap3 + sql( + "CALL %s.system.create_changelog_view(table => '%s', " + + "options => map('start-snapshot-id','%s'), " + + "net_changes => true)", + catalogName, tableName, snap1.snapshotId()); + + assertEquals( + "Rows should match", + ImmutableList.of( + row(2, "b", 11, DELETE, 0, snap2.snapshotId()), + row(3, "c", 15, INSERT, 1, snap3.snapshotId())), + sql("select * from %s order by _change_ordinal, data", viewName)); + } + + @TestTemplate + public void testNetChangesWithComputeUpdates() { + createTableWithTwoColumns(); + assertThatThrownBy( + () -> + sql( + "CALL %s.system.create_changelog_view(table => '%s', identifier_columns => array('id'), net_changes => true)", + catalogName, 
tableName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Not support net changes with update images"); + } + + @TestTemplate + public void testUpdateWithInComparableType() { + sql( + "CREATE TABLE %s (id INT NOT NULL, data MAP, age INT) USING iceberg", + tableName); + + assertThatThrownBy( + () -> + sql("CALL %s.system.create_changelog_view(table => '%s')", catalogName, tableName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining( + "Identifier field is required as table contains unorderable columns: [data]"); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java new file mode 100644 index 000000000000..42eb2af774e9 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java @@ -0,0 +1,1465 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.DataOperations.DELETE; +import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE; +import static org.apache.iceberg.SnapshotSummary.ADD_POS_DELETE_FILES_PROP; +import static org.apache.iceberg.TableProperties.DELETE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.DELETE_MODE; +import static org.apache.iceberg.TableProperties.DELETE_MODE_DEFAULT; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.iceberg.TableProperties.SPLIT_SIZE; +import static org.apache.spark.sql.functions.lit; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromTableWithFilters; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.RowLevelWrite; +import org.apache.spark.sql.execution.SparkPlan; +import org.apache.spark.sql.execution.datasources.v2.OptimizeMetadataOnlyDeleteFromTable; +import org.apache.spark.sql.internal.SQLConf; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + 
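+// Abstract DELETE suite shared by the Spark 4.0 row-level operation tests. Concrete subclasses
+// such as TestCopyOnWriteDelete (added earlier in this patch) override extraTableProperties()
+// to pin the delete mode, and ParameterizedTestExtension expands each @TestTemplate case over
+// the parameters used throughout the suite (catalogName, fileFormat, branch).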
+@ExtendWith(ParameterizedTestExtension.class) +public abstract class TestDelete extends SparkRowLevelOperationsTestBase { + + @BeforeAll + public static void setupSparkConf() { + spark.conf().set("spark.sql.shuffle.partitions", "4"); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS deleted_id"); + sql("DROP TABLE IF EXISTS deleted_dep"); + sql("DROP TABLE IF EXISTS parquet_table"); + } + + @TestTemplate + public void testDeleteWithVectorizedReads() throws NoSuchTableException { + assumeThat(supportsVectorization()).isTrue(); + + createAndInitPartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(2, "hr")); + append(tableName, new Employee(3, "hardware"), new Employee(4, "hardware")); + + createBranchIfNeeded(); + + SparkPlan plan = executeAndKeepPlan("DELETE FROM %s WHERE id = 2", commitTarget()); + assertAllBatchScansVectorized(plan); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 3 snapshots").hasSize(3); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + } else { + validateMergeOnRead(currentSnapshot, "1", "1", null); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(3, "hardware"), row(4, "hardware")), + sql("SELECT * FROM %s ORDER BY id ASC", selectTarget())); + } + + @TestTemplate + public void testCoalesceDelete() throws Exception { + createAndInitUnpartitionedTable(); + + Employee[] employees = new Employee[100]; + for (int index = 0; index < 100; index++) { + employees[index] = new Employee(index, "hr"); + } + append(tableName, employees); + append(tableName, employees); + append(tableName, employees); + append(tableName, employees); + + // set the open file cost large enough to produce a separate scan task per file + // use range distribution to trigger a shuffle + Map tableProps = + ImmutableMap.of( + SPLIT_OPEN_FILE_COST, + String.valueOf(Integer.MAX_VALUE), + DELETE_DISTRIBUTION_MODE, + DistributionMode.RANGE.modeName()); + sql("ALTER TABLE %s SET TBLPROPERTIES (%s)", tableName, tablePropsAsString(tableProps)); + + createBranchIfNeeded(); + + // enable AQE and set the advisory partition size big enough to trigger combining + // set the number of shuffle partitions to 200 to distribute the work across reducers + // set the advisory partition size for shuffles small enough to ensure writes override it + withSQLConf( + ImmutableMap.of( + SQLConf.SHUFFLE_PARTITIONS().key(), + "200", + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), + "true", + SQLConf.COALESCE_PARTITIONS_ENABLED().key(), + "true", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES().key(), + "100", + SparkSQLProperties.ADVISORY_PARTITION_SIZE, + String.valueOf(256 * 1024 * 1024)), + () -> { + SparkPlan plan = + executeAndKeepPlan("DELETE FROM %s WHERE mod(id, 2) = 0", commitTarget()); + assertThat(plan.toString()).contains("REBALANCE_PARTITIONS_BY_COL"); + }); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, branch); + + if (mode(table) == COPY_ON_WRITE) { + // CoW DELETE requests the remaining records to be range distributed by `_file`, `_pos` + // every task has data for each of 200 reducers + // AQE detects that all shuffle blocks are small and processes them in 1 task + // otherwise, there would be 200 tasks writing to the table + 
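+      // validateProperty (inherited from the shared test base rather than defined in this file)
+      // asserts on the snapshot summary entry, so ADDED_FILES_PROP of "1" confirms the coalesced
+      // write produced a single data file.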
validateProperty(snapshot, SnapshotSummary.ADDED_FILES_PROP, "1"); + } else { + // MoR DELETE requests the deleted records to be range distributed by partition and `_file` + // each task contains only 1 file and therefore writes only 1 shuffle block + // that means 4 shuffle blocks are distributed among 200 reducers + // AQE detects that all 4 shuffle blocks are small and processes them in 1 task + // otherwise, there would be 4 tasks processing 1 shuffle block each + validateProperty(snapshot, SnapshotSummary.ADDED_DELETE_FILES_PROP, "1"); + } + + assertThat(scalarSql("SELECT COUNT(*) FROM %s", commitTarget())) + .as("Row count must match") + .isEqualTo(200L); + } + + @TestTemplate + public void testSkewDelete() throws Exception { + createAndInitPartitionedTable(); + + Employee[] employees = new Employee[100]; + for (int index = 0; index < 100; index++) { + employees[index] = new Employee(index, "hr"); + } + append(tableName, employees); + append(tableName, employees); + append(tableName, employees); + append(tableName, employees); + + // set the open file cost large enough to produce a separate scan task per file + // use hash distribution to trigger a shuffle + Map tableProps = + ImmutableMap.of( + SPLIT_OPEN_FILE_COST, + String.valueOf(Integer.MAX_VALUE), + DELETE_DISTRIBUTION_MODE, + DistributionMode.HASH.modeName()); + sql("ALTER TABLE %s SET TBLPROPERTIES (%s)", tableName, tablePropsAsString(tableProps)); + + createBranchIfNeeded(); + + // enable AQE and set the advisory partition size small enough to trigger a split + // set the number of shuffle partitions to 2 to only have 2 reducers + // set the advisory partition size for shuffles big enough to ensure writes override it + withSQLConf( + ImmutableMap.of( + SQLConf.SHUFFLE_PARTITIONS().key(), + "2", + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), + "true", + SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED().key(), + "true", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES().key(), + "256MB", + SparkSQLProperties.ADVISORY_PARTITION_SIZE, + "100"), + () -> { + SparkPlan plan = + executeAndKeepPlan("DELETE FROM %s WHERE mod(id, 2) = 0", commitTarget()); + assertThat(plan.toString()).contains("REBALANCE_PARTITIONS_BY_COL"); + }); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, branch); + + if (mode(table) == COPY_ON_WRITE) { + // CoW DELETE requests the remaining records to be clustered by `_file` + // each task contains only 1 file and therefore writes only 1 shuffle block + // that means 4 shuffle blocks are distributed among 2 reducers + // AQE detects that all shuffle blocks are big and processes them in 4 independent tasks + // otherwise, there would be 2 tasks processing 2 shuffle blocks each + validateProperty(snapshot, SnapshotSummary.ADDED_FILES_PROP, "4"); + } else { + // MoR DELETE requests the deleted records to be clustered by `_spec_id` and `_partition` + // all tasks belong to the same partition and therefore write only 1 shuffle block per task + // that means there are 4 shuffle blocks, all assigned to the same reducer + // AQE detects that all 4 shuffle blocks are big and processes them in 4 separate tasks + // otherwise, there would be 1 task processing 4 shuffle blocks + validateProperty(snapshot, SnapshotSummary.ADDED_DELETE_FILES_PROP, "4"); + } + + assertThat(scalarSql("SELECT COUNT(*) FROM %s", commitTarget())) + .as("Row count must match") + .isEqualTo(200L); + } + + @TestTemplate + public void testDeleteWithoutScanningTable() throws 
Exception { + createAndInitPartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(3, "hr")); + createBranchIfNeeded(); + append(new Employee(1, "hardware"), new Employee(2, "hardware")); + + Table table = validationCatalog.loadTable(tableIdent); + + List manifestLocations = + SnapshotUtil.latestSnapshot(table, branch).allManifests(table.io()).stream() + .map(ManifestFile::path) + .collect(Collectors.toList()); + + withUnavailableLocations( + manifestLocations, + () -> { + LogicalPlan parsed = parsePlan("DELETE FROM %s WHERE dep = 'hr'", commitTarget()); + + LogicalPlan analyzed = spark.sessionState().analyzer().execute(parsed); + assertThat(analyzed).isInstanceOf(RowLevelWrite.class); + + LogicalPlan optimized = OptimizeMetadataOnlyDeleteFromTable.apply(analyzed); + assertThat(optimized).isInstanceOf(DeleteFromTableWithFilters.class); + }); + + sql("DELETE FROM %s WHERE dep = 'hr'", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hardware"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testDeleteFileThenMetadataDelete() throws Exception { + assumeThat(fileFormat) + .as("Avro does not support metadata delete") + .isNotEqualTo(FileFormat.AVRO); + createAndInitUnpartitionedTable(); + createBranchIfNeeded(); + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", commitTarget()); + + // MOR mode: writes a delete file as null cannot be deleted by metadata + sql("DELETE FROM %s AS t WHERE t.id IS NULL", commitTarget()); + + // Metadata Delete + Table table = Spark3Util.loadIcebergTable(spark, tableName); + List dataFilesBefore = TestHelpers.dataFiles(table, branch); + + sql("DELETE FROM %s AS t WHERE t.id = 1", commitTarget()); + + List dataFilesAfter = TestHelpers.dataFiles(table, branch); + assertThat(dataFilesAfter) + .as("Data file should have been removed") + .hasSizeLessThan(dataFilesBefore.size()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testDeleteWithPartitionedTable() throws Exception { + createAndInitPartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(3, "hr")); + append(tableName, new Employee(1, "hardware"), new Employee(2, "hardware")); + + // row level delete + sql("DELETE FROM %s WHERE id = 1", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + List rowLevelDeletePartitions = + spark.sql("SELECT * FROM " + tableName + ".partitions ").collectAsList(); + assertThat(rowLevelDeletePartitions) + .as("row level delete does not reduce number of partition") + .hasSize(2); + + // partition aligned delete + sql("DELETE FROM %s WHERE dep = 'hr'", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + List actualPartitions = + spark.sql("SELECT * FROM " + tableName + ".partitions ").collectAsList(); + assertThat(actualPartitions).as("partition aligned delete results in 1 partition").hasSize(1); + } + + @TestTemplate + public void testDeleteWithFalseCondition() { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName); + createBranchIfNeeded(); + + sql("DELETE FROM %s WHERE id = 1 AND id > 20", 
commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 2 snapshots").hasSize(2); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testDeleteFromEmptyTable() { + assumeThat(branch).as("Custom branch does not exist for empty table").isNotEqualTo("test"); + createAndInitUnpartitionedTable(); + + sql("DELETE FROM %s WHERE id IN (1)", commitTarget()); + sql("DELETE FROM %s WHERE dep = 'hr'", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 2 snapshots").hasSize(2); + + assertEquals( + "Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testDeleteFromNonExistingCustomBranch() { + assumeThat(branch).as("Test only applicable to custom branch").isEqualTo("test"); + createAndInitUnpartitionedTable(); + + assertThatThrownBy(() -> sql("DELETE FROM %s WHERE id IN (1)", commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot use branch (does not exist): test"); + } + + @TestTemplate + public void testExplain() { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); + + sql("EXPLAIN DELETE FROM %s WHERE id <=> 1", commitTarget()); + + sql("EXPLAIN DELETE FROM %s WHERE true", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 1 snapshot").hasSize(1); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", commitTarget())); + } + + @TestTemplate + public void testDeleteWithAlias() { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); + + sql("DELETE FROM %s AS t WHERE t.id IS NULL", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testDeleteWithDynamicFileFiltering() throws NoSuchTableException { + createAndInitPartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(3, "hr")); + createBranchIfNeeded(); + append(new Employee(1, "hardware"), new Employee(2, "hardware")); + + sql("DELETE FROM %s WHERE id = 2", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 3 snapshots").hasSize(3); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + } else { + validateMergeOnRead(currentSnapshot, "1", "1", null); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } + + @TestTemplate + public void testDeleteNonExistingRecords() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); + + sql("DELETE FROM %s AS t WHERE t.id > 10", 
commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 2 snapshots").hasSize(2); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + + if (fileFormat.equals(FileFormat.ORC) || fileFormat.equals(FileFormat.PARQUET)) { + validateDelete(currentSnapshot, "0", null); + } else { + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "0", null, null); + } else { + validateMergeOnRead(currentSnapshot, "0", null, null); + } + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void deleteSingleRecordProducesDeleteOperation() throws NoSuchTableException { + createAndInitPartitionedTable(); + append(tableName, new Employee(1, "eng"), new Employee(2, "eng"), new Employee(3, "eng")); + + sql("DELETE FROM %s WHERE id = 2", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).hasSize(2); + + Snapshot currentSnapshot = table.currentSnapshot(); + + if (mode(table) == COPY_ON_WRITE) { + // this is an OverwriteFiles and produces "overwrite" + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + } else { + // this is a RowDelta that produces a "delete" instead of "overwrite" + validateMergeOnRead(currentSnapshot, "1", "1", null); + validateProperty(currentSnapshot, ADD_POS_DELETE_FILES_PROP, "1"); + } + + assertThat(sql("SELECT * FROM %s", tableName)) + .containsExactlyInAnyOrder(row(1, "eng"), row(3, "eng")); + } + + @TestTemplate + public void testDeleteWithoutCondition() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", commitTarget()); + sql("INSERT INTO TABLE %s VALUES (null, 'hr')", commitTarget()); + + sql("DELETE FROM %s", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 4 snapshots").hasSize(4); + + // should be a delete instead of an overwrite as it is done through a metadata operation + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + validateDelete(currentSnapshot, "2", "3"); + + assertEquals( + "Should have expected rows", ImmutableList.of(), sql("SELECT * FROM %s", commitTarget())); + } + + @TestTemplate + public void testDeleteUsingMetadataWithComplexCondition() { + createAndInitPartitionedTable(); + + sql("INSERT INTO %s VALUES (1, 'dep1')", tableName); + createBranchIfNeeded(); + sql("INSERT INTO %s VALUES (2, 'dep2')", commitTarget()); + sql("INSERT INTO %s VALUES (null, 'dep3')", commitTarget()); + + sql( + "DELETE FROM %s WHERE dep > 'dep2' OR dep = CAST(4 AS STRING) OR dep = 'dep2'", + commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 4 snapshots").hasSize(4); + + // should be a delete instead of an overwrite as it is done through a metadata operation + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + validateDelete(currentSnapshot, "2", "2"); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "dep1")), + sql("SELECT * FROM %s", selectTarget())); + } + + @TestTemplate + public void testDeleteWithArbitraryPartitionPredicates() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", 
tableName); + createBranchIfNeeded(); + sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", commitTarget()); + sql("INSERT INTO TABLE %s VALUES (null, 'hr')", commitTarget()); + + // %% is an escaped version of % + sql("DELETE FROM %s WHERE id = 10 OR dep LIKE '%%ware'", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 4 snapshots").hasSize(4); + + // should be a "delete" instead of an "overwrite" as only data files have been removed (COW) / + // delete files have been added (MOR) + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + assertThat(currentSnapshot.operation()).isEqualTo(DELETE); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "1", "1", null); + } else { + validateMergeOnRead(currentSnapshot, "1", "1", null); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testDeleteWithNonDeterministicCondition() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName); + createBranchIfNeeded(); + + assertThatThrownBy(() -> sql("DELETE FROM %s WHERE id = 1 AND rand() > 0.5", commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("The operator expects a deterministic expression"); + } + + @TestTemplate + public void testDeleteWithFoldableConditions() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName); + createBranchIfNeeded(); + + // should keep all rows and don't trigger execution + sql("DELETE FROM %s WHERE false", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // should keep all rows and don't trigger execution + sql("DELETE FROM %s WHERE 50 <> 50", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // should keep all rows and don't trigger execution + sql("DELETE FROM %s WHERE 1 > null", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // should remove all rows + sql("DELETE FROM %s WHERE 21 = 21", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 2 snapshots").hasSize(2); + } + + @TestTemplate + public void testDeleteWithNullConditions() { + createAndInitPartitionedTable(); + + sql( + "INSERT INTO TABLE %s VALUES (0, null), (1, 'hr'), (2, 'hardware'), (null, 'hr')", + tableName); + createBranchIfNeeded(); + + // should keep all rows as null is never equal to null + sql("DELETE FROM %s WHERE dep = null", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + // null = 'software' -> null + // should delete using metadata operation only + sql("DELETE FROM %s WHERE dep = 'software'", commitTarget()); + assertEquals( + "Should 
have expected rows", + ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + // should delete using metadata operation only + sql("DELETE FROM %s WHERE dep <=> NULL", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 3 snapshots").hasSize(3); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + validateDelete(currentSnapshot, "1", "1"); + } + + @TestTemplate + public void testDeleteWithInAndNotInConditions() { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); + + sql("DELETE FROM %s WHERE id IN (1, null)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql("DELETE FROM %s WHERE id NOT IN (null, 1)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql("DELETE FROM %s WHERE id NOT IN (1, 10)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testDeleteWithMultipleRowGroupsParquet() throws NoSuchTableException { + assumeThat(fileFormat).isEqualTo(FileFormat.PARQUET); + + createAndInitPartitionedTable(); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 100); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, SPLIT_SIZE, 100); + + List ids = Lists.newArrayListWithCapacity(200); + for (int id = 1; id <= 200; id++) { + ids.add(id); + } + Dataset df = + spark + .createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")); + df.coalesce(1).writeTo(tableName).append(); + createBranchIfNeeded(); + + assertThat(spark.table(commitTarget()).count()).isEqualTo(200); + + // delete a record from one of two row groups and copy over the second one + sql("DELETE FROM %s WHERE id IN (200, 201)", commitTarget()); + + assertThat(spark.table(commitTarget()).count()).isEqualTo(199); + } + + @TestTemplate + public void testDeleteWithConditionOnNestedColumn() { + createAndInitNestedColumnsTable(); + + sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\"))", tableName); + createBranchIfNeeded(); + sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2, \"c2\", \"v2\"))", commitTarget()); + + sql("DELETE FROM %s WHERE complex.c1 = id + 2", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2)), + sql("SELECT id FROM %s", selectTarget())); + + sql("DELETE FROM %s t WHERE t.complex.c1 = id", commitTarget()); + assertEquals( + "Should have expected rows", ImmutableList.of(), sql("SELECT id FROM %s", selectTarget())); + } + + @TestTemplate + public void testDeleteWithInSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 
'hr')", tableName); + createBranchIfNeeded(); + + createOrReplaceView("deleted_id", Arrays.asList(0, 1, null), Encoders.INT()); + createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + sql( + "DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id) AND dep IN (SELECT * from deleted_dep)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + append(new Employee(1, "hr"), new Employee(-1, "hr")); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s WHERE id IS NULL OR id IN (SELECT value + 2 FROM deleted_id)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "hr")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + append(new Employee(null, "hr"), new Employee(2, "hr")); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(2, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s WHERE id IN (SELECT value + 2 FROM deleted_id) AND dep = 'hr'", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testDeleteWithMultiColumnInSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); + + List<Employee> deletedEmployees = + Arrays.asList(new Employee(null, "hr"), new Employee(1, "hr")); + createOrReplaceView("deleted_employee", deletedEmployees, Encoders.bean(Employee.class)); + + sql("DELETE FROM %s WHERE (id, dep) IN (SELECT id, dep FROM deleted_employee)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testDeleteWithNotInSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); + + createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + // the file filter subquery (nested loop left-anti join) returns 0 records + sql("DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, 
"hr"), row(2, "hardware"), row(null, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id) OR dep IN ('software', 'hr')", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s t WHERE " + + "id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) AND " + + "EXISTS (SELECT 1 FROM deleted_dep WHERE t.dep = deleted_dep.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s t WHERE " + + "id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) OR " + + "EXISTS (SELECT 1 FROM deleted_dep WHERE t.dep = deleted_dep.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testDeleteOnNonIcebergTableNotSupported() { + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + + sql("CREATE TABLE parquet_table (c1 INT, c2 INT) USING parquet"); + + assertThatThrownBy(() -> sql("DELETE FROM parquet_table WHERE c1 = -100")) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("does not support DELETE"); + } + + @TestTemplate + public void testDeleteWithExistSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); + + createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + sql( + "DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value) OR t.id IS NULL", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s", selectTarget())); + + sql( + "DELETE FROM %s t WHERE " + + "EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value) AND " + + "EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s", selectTarget())); + } + + @TestTemplate + public void testDeleteWithNotExistsSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); + + createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), 
Encoders.STRING()); + + sql( + "DELETE FROM %s t WHERE " + + "NOT EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value + 2) AND " + + "NOT EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s t WHERE NOT EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + String subquery = "SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2"; + sql("DELETE FROM %s t WHERE NOT EXISTS (%s) OR t.id = 1", commitTarget(), subquery); + assertEquals( + "Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testDeleteWithScalarSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); + + createOrReplaceView("deleted_id", Arrays.asList(1, 100, null), Encoders.INT()); + + // TODO: Spark does not support AQE and DPP with aggregates at the moment + withSQLConf( + ImmutableMap.of(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"), + () -> { + sql("DELETE FROM %s t WHERE id <= (SELECT min(value) FROM deleted_id)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + }); + } + + @TestTemplate + public void testDeleteThatRequiresGroupingBeforeWrite() throws NoSuchTableException { + createAndInitPartitionedTable(); + + append(tableName, new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); + createBranchIfNeeded(); + append(new Employee(0, "ops"), new Employee(1, "ops"), new Employee(2, "ops")); + append(new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); + append(new Employee(0, "ops"), new Employee(1, "ops"), new Employee(2, "ops")); + + createOrReplaceView("deleted_id", Arrays.asList(1, 100), Encoders.INT()); + + String originalNumOfShufflePartitions = spark.conf().get("spark.sql.shuffle.partitions"); + try { + // set the num of shuffle partitions to 1 to ensure we have only 1 writing task + spark.conf().set("spark.sql.shuffle.partitions", "1"); + + sql("DELETE FROM %s t WHERE id IN (SELECT * FROM deleted_id)", commitTarget()); + assertThat(spark.table(commitTarget()).count()) + .as("Should have expected num of rows") + .isEqualTo(8L); + } finally { + spark.conf().set("spark.sql.shuffle.partitions", originalNumOfShufflePartitions); + } + } + + @TestTemplate + public synchronized void testDeleteWithSerializableIsolation() throws InterruptedException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + assumeThat(catalogName).isNotEqualToIgnoringCase("testhadoop"); + // if caching is off, the table is eagerly refreshed during runtime filtering + // this can cause a validation exception as concurrent changes would be visible + assumeThat(cachingCatalogEnabled()).isTrue(); + + createAndInitUnpartitionedTable(); + createOrReplaceView("deleted_id", Collections.singletonList(1), Encoders.INT()); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, 
DELETE_ISOLATION_LEVEL, "serializable"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // delete thread + Future deleteFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + int currentNumOperations = numOperations; + Awaitility.await() + .pollInterval(10, TimeUnit.MILLISECONDS) + .atMost(5, TimeUnit.SECONDS) + .until(() -> barrier.get() >= currentNumOperations * 2); + + sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id)", commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + // load the table via the validation catalog to use another table instance + Table table = validationCatalog.loadTable(tableIdent); + + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + int currentNumOperations = numOperations; + Awaitility.await() + .pollInterval(10, TimeUnit.MILLISECONDS) + .atMost(5, TimeUnit.SECONDS) + .until(() -> !shouldAppend.get() || barrier.get() >= currentNumOperations * 2); + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + } + + barrier.incrementAndGet(); + } + }); + + try { + assertThatThrownBy(deleteFuture::get) + .isInstanceOf(ExecutionException.class) + .cause() + .isInstanceOf(ValidationException.class) + .hasMessageContaining("Found conflicting files that can contain"); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + assertThat(executorService.awaitTermination(2, TimeUnit.MINUTES)).as("Timeout").isTrue(); + } + + @TestTemplate + public synchronized void testDeleteWithSnapshotIsolation() + throws InterruptedException, ExecutionException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + assumeThat(catalogName).isNotEqualToIgnoringCase("testhadoop"); + // if caching is off, the table is eagerly refreshed during runtime filtering + // this can cause a validation exception as concurrent changes would be visible + assumeThat(cachingCatalogEnabled()).isTrue(); + + createAndInitUnpartitionedTable(); + createOrReplaceView("deleted_id", Collections.singletonList(1), Encoders.INT()); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, DELETE_ISOLATION_LEVEL, "snapshot"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // delete thread + Future deleteFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < 20; numOperations++) { + int currentNumOperations = 
numOperations; + Awaitility.await() + .pollInterval(10, TimeUnit.MILLISECONDS) + .atMost(5, TimeUnit.SECONDS) + .until(() -> barrier.get() >= currentNumOperations * 2); + + sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id)", commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + // load the table via the validation catalog to use another table instance for inserts + Table table = validationCatalog.loadTable(tableIdent); + + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < 20; numOperations++) { + int currentNumOperations = numOperations; + Awaitility.await() + .pollInterval(10, TimeUnit.MILLISECONDS) + .atMost(5, TimeUnit.SECONDS) + .until(() -> !shouldAppend.get() || barrier.get() >= currentNumOperations * 2); + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + } + + barrier.incrementAndGet(); + } + }); + + try { + deleteFuture.get(); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + assertThat(executorService.awaitTermination(2, TimeUnit.MINUTES)).as("Timeout").isTrue(); + } + + @TestTemplate + public void testDeleteRefreshesRelationCache() throws NoSuchTableException { + createAndInitPartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(3, "hr")); + createBranchIfNeeded(); + append(new Employee(1, "hardware"), new Employee(2, "hardware")); + + Dataset query = spark.sql("SELECT * FROM " + commitTarget() + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals( + "View should have correct data", + ImmutableList.of(row(1, "hardware"), row(1, "hr")), + sql("SELECT * FROM tmp ORDER BY id, dep")); + + sql("DELETE FROM %s WHERE id = 1", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 3 snapshots").hasSize(3); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "2", "2", "2"); + } else { + validateMergeOnRead(currentSnapshot, "2", "2", null); + } + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", commitTarget())); + + assertEquals( + "Should refresh the relation cache", + ImmutableList.of(), + sql("SELECT * FROM tmp ORDER BY id, dep")); + + spark.sql("UNCACHE TABLE tmp"); + } + + @TestTemplate + public void testDeleteWithMultipleSpecs() { + createAndInitTable("id INT, dep STRING, category STRING"); + + // write an unpartitioned file + append(tableName, "{ \"id\": 1, \"dep\": \"hr\", \"category\": \"c1\"}"); + createBranchIfNeeded(); + + // write a file partitioned by dep + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + append( + commitTarget(), + "{ \"id\": 2, \"dep\": \"hr\", \"category\": \"c1\" }\n" + + "{ \"id\": 3, \"dep\": \"hr\", \"category\": \"c1\" }"); + + // write a file partitioned by dep and category + sql("ALTER TABLE %s ADD PARTITION FIELD category", tableName); 
+ append(commitTarget(), "{ \"id\": 5, \"dep\": \"hr\", \"category\": \"c1\"}"); + + // write another file partitioned by dep + sql("ALTER TABLE %s DROP PARTITION FIELD category", tableName); + append(commitTarget(), "{ \"id\": 7, \"dep\": \"hr\", \"category\": \"c1\"}"); + + sql("DELETE FROM %s WHERE id IN (1, 3, 5, 7)", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 5 snapshots").hasSize(5); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "3", "4", "1"); + } else { + validateMergeOnRead(currentSnapshot, "3", "3", null); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hr", "c1")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testDeleteToWapBranch() throws NoSuchTableException { + assumeThat(branch).as("WAP branch only works for table identifier without branch").isNull(); + + createAndInitPartitionedTable(); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + append(new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql("DELETE FROM %s t WHERE id=0", tableName); + assertThat(spark.table(tableName).count()) + .as("Should have expected num of rows when reading table") + .isEqualTo(2L); + assertThat(spark.table(tableName + ".branch_wap").count()) + .as("Should have expected num of rows when reading WAP branch") + .isEqualTo(2L); + assertThat(spark.table(tableName + ".branch_main").count()) + .as("Should not modify main branch") + .isEqualTo(3L); + }); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql("DELETE FROM %s t WHERE id=1", tableName); + assertThat(spark.table(tableName).count()) + .as("Should have expected num of rows when reading table with multiple writes") + .isEqualTo(1L); + assertThat(spark.table(tableName + ".branch_wap").count()) + .as("Should have expected num of rows when reading WAP branch with multiple writes") + .isEqualTo(1L); + assertThat(spark.table(tableName + ".branch_main").count()) + .as("Should not modify main branch with multiple writes") + .isEqualTo(3L); + }); + } + + @TestTemplate + public void testDeleteToWapBranchWithTableBranchIdentifier() throws NoSuchTableException { + assumeThat(branch).as("Test must have branch name part in table identifier").isNotNull(); + + createAndInitPartitionedTable(); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + append(tableName, new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); + createBranchIfNeeded(); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> + assertThatThrownBy(() -> sql("DELETE FROM %s t WHERE id=0", commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage( + String.format( + "Cannot write to both branch and WAP branch, but got branch [%s] and WAP branch [wap]", + branch))); + } + + @TestTemplate + public void testDeleteToCustomWapBranchWithoutWhereClause() throws NoSuchTableException { + assumeThat(branch) + .as("Run only if custom WAP branch is not main") + .isNotNull() + .isNotEqualTo(SnapshotRef.MAIN_BRANCH); + + createAndInitPartitionedTable(); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 
'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + append(tableName, new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); + createBranchIfNeeded(); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, branch), + () -> { + sql("DELETE FROM %s t WHERE id=1", tableName); + assertThat(spark.table(tableName).count()).isEqualTo(2L); + assertThat(spark.table(tableName + ".branch_" + branch).count()).isEqualTo(2L); + assertThat(spark.table(tableName + ".branch_main").count()) + .as("Should not modify main branch") + .isEqualTo(3L); + }); + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, branch), + () -> { + sql("DELETE FROM %s t", tableName); + assertThat(spark.table(tableName).count()).isEqualTo(0L); + assertThat(spark.table(tableName + ".branch_" + branch).count()).isEqualTo(0L); + assertThat(spark.table(tableName + ".branch_main").count()) + .as("Should not modify main branch") + .isEqualTo(3L); + }); + } + + @TestTemplate + public void testDeleteWithFilterOnNestedColumn() { + createAndInitNestedColumnsTable(); + + sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\"))", tableName); + sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2, \"c2\", \"v2\"))", tableName); + + sql("DELETE FROM %s WHERE complex.c1 > 3", tableName); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1), row(2)), + sql("SELECT id FROM %s order by id", tableName)); + + sql("DELETE FROM %s WHERE complex.c1 = 3", tableName); + assertEquals( + "Should have expected rows", ImmutableList.of(row(2)), sql("SELECT id FROM %s", tableName)); + + sql("DELETE FROM %s t WHERE t.complex.c1 = 2", tableName); + assertEquals( + "Should have expected rows", ImmutableList.of(), sql("SELECT id FROM %s", tableName)); + } + + // TODO: multiple stripes for ORC + + protected void createAndInitPartitionedTable() { + sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg PARTITIONED BY (dep)", tableName); + initTable(); + } + + protected void createAndInitUnpartitionedTable() { + sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg", tableName); + initTable(); + } + + protected void createAndInitNestedColumnsTable() { + sql("CREATE TABLE %s (id INT, complex STRUCT<c1:INT,c2:STRING>) USING iceberg", tableName); + initTable(); + } + + protected void append(Employee... employees) throws NoSuchTableException { + append(commitTarget(), employees); + } + + protected void append(String target, Employee... employees) throws NoSuchTableException { + List<Employee> input = Arrays.asList(employees); + Dataset<Row> inputDF = spark.createDataFrame(input, Employee.class); + inputDF.coalesce(1).writeTo(target).append(); + } + + private RowLevelOperationMode mode(Table table) { + String modeName = table.properties().getOrDefault(DELETE_MODE, DELETE_MODE_DEFAULT); + return RowLevelOperationMode.fromName(modeName); + } + + private LogicalPlan parsePlan(String query, Object... 
args) { + try { + return spark.sessionState().sqlParser().parsePlan(String.format(query, args)); + } catch (ParseException e) { + throw new RuntimeException(e); + } + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestExpireSnapshotsProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestExpireSnapshotsProcedure.java new file mode 100644 index 000000000000..34fec09add7c --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestExpireSnapshotsProcedure.java @@ -0,0 +1,560 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.GenericBlobMetadata; +import org.apache.iceberg.GenericStatisticsFile; +import org.apache.iceberg.PartitionStatisticsFile; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.puffin.Blob; +import org.apache.iceberg.puffin.Puffin; +import org.apache.iceberg.puffin.PuffinWriter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestExpireSnapshotsProcedure extends ExtensionsTestBase { + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testExpireSnapshotsInEmptyTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + List output = sql("CALL %s.system.expire_snapshots('%s')", catalogName, tableIdent); + assertEquals( + "Should not delete any files", ImmutableList.of(row(0L, 0L, 0L, 0L, 0L, 
0L)), output); + } + + @TestTemplate + public void testExpireSnapshotsUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + Timestamp secondSnapshotTimestamp = + Timestamp.from(Instant.ofEpochMilli(secondSnapshot.timestampMillis())); + + assertThat(table.snapshots()).as("Should be 2 snapshots").hasSize(2); + + // expire without retainLast param + List output1 = + sql( + "CALL %s.system.expire_snapshots('%s', TIMESTAMP '%s')", + catalogName, tableIdent, secondSnapshotTimestamp); + assertEquals( + "Procedure output must match", ImmutableList.of(row(0L, 0L, 0L, 0L, 1L, 0L)), output1); + + table.refresh(); + + assertThat(table.snapshots()).as("Should expire one snapshot").hasSize(1); + + sql("INSERT OVERWRITE %s VALUES (3, 'c')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(3L, "c"), row(4L, "d")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + assertThat(table.snapshots()).as("Should be 3 snapshots").hasSize(3); + + // expire with retainLast param + List output = + sql( + "CALL %s.system.expire_snapshots('%s', TIMESTAMP '%s', 2)", + catalogName, tableIdent, currentTimestamp); + assertEquals( + "Procedure output must match", ImmutableList.of(row(2L, 0L, 0L, 2L, 1L, 0L)), output); + } + + @TestTemplate + public void testExpireSnapshotUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.snapshots()).as("Should be 2 snapshots").hasSize(2); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + List output = + sql( + "CALL %s.system.expire_snapshots(older_than => TIMESTAMP '%s',table => '%s')", + catalogName, currentTimestamp, tableIdent); + assertEquals( + "Procedure output must match", ImmutableList.of(row(0L, 0L, 0L, 0L, 1L, 0L)), output); + } + + @TestTemplate + public void testExpireSnapshotsGCDisabled() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'false')", tableName, GC_ENABLED); + + assertThatThrownBy(() -> sql("CALL %s.system.expire_snapshots('%s')", catalogName, tableIdent)) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith("Cannot expire snapshots: GC is disabled"); + } + + @TestTemplate + public void testInvalidExpireSnapshotsCases() { + assertThatThrownBy(() -> sql("CALL %s.system.expire_snapshots('n', table => 't')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Named and positional arguments cannot be mixed"); + + assertThatThrownBy(() -> sql("CALL %s.custom.expire_snapshots('n', 't')", catalogName)) + 
.isInstanceOf(ParseException.class) + .satisfies( + exception -> { + ParseException parseException = (ParseException) exception; + assertThat(parseException.getErrorClass()).isEqualTo("PARSE_SYNTAX_ERROR"); + assertThat(parseException.getMessageParameters().get("error")).isEqualTo("'CALL'"); + }); + + assertThatThrownBy(() -> sql("CALL %s.system.expire_snapshots()", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [table]"); + + assertThatThrownBy(() -> sql("CALL %s.system.expire_snapshots('n', 2.2)", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Wrong arg type for older_than: cannot cast DecimalType(2,1) to TimestampType"); + + assertThatThrownBy(() -> sql("CALL %s.system.expire_snapshots('')", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot handle an empty identifier for argument table"); + } + + @TestTemplate + public void testResolvingTableInAnotherCatalog() throws IOException { + String anotherCatalog = "another_" + catalogName; + spark.conf().set("spark.sql.catalog." + anotherCatalog, SparkCatalog.class.getName()); + spark.conf().set("spark.sql.catalog." + anotherCatalog + ".type", "hadoop"); + spark + .conf() + .set( + "spark.sql.catalog." + anotherCatalog + ".warehouse", + Files.createTempDirectory(temp, "junit").toFile().toURI().toString()); + + sql( + "CREATE TABLE %s.%s (id bigint NOT NULL, data string) USING iceberg", + anotherCatalog, tableIdent); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.expire_snapshots('%s')", + catalogName, anotherCatalog + "." + tableName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Cannot run procedure in catalog"); + } + + @TestTemplate + public void testConcurrentExpireSnapshots() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + sql("INSERT INTO TABLE %s VALUES (3, 'c')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + List output = + sql( + "CALL %s.system.expire_snapshots(" + + "older_than => TIMESTAMP '%s'," + + "table => '%s'," + + "max_concurrent_deletes => %s)", + catalogName, currentTimestamp, tableIdent, 4); + assertEquals( + "Expiring snapshots concurrently should succeed", + ImmutableList.of(row(0L, 0L, 0L, 0L, 3L, 0L)), + output); + } + + @TestTemplate + public void testConcurrentExpireSnapshotsWithInvalidInput() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.expire_snapshots(table => '%s', max_concurrent_deletes => %s)", + catalogName, tableIdent, 0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("max_concurrent_deletes should have value > 0, value: 0"); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.expire_snapshots(table => '%s', max_concurrent_deletes => %s)", + catalogName, tableIdent, -1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("max_concurrent_deletes should have value > 0, value: -1"); + } + + @TestTemplate + public void testExpireDeleteFiles() throws Exception { + sql( + "CREATE TABLE %s (id bigint, data string) USING iceberg TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List records = + Lists.newArrayList( 
+ new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(TestHelpers.deleteManifests(table)).as("Should have 1 delete manifest").hasSize(1); + assertThat(TestHelpers.deleteFiles(table)).as("Should have 1 delete file").hasSize(1); + Path deleteManifestPath = new Path(TestHelpers.deleteManifests(table).iterator().next().path()); + Path deleteFilePath = + new Path(String.valueOf(TestHelpers.deleteFiles(table).iterator().next().location())); + + sql( + "CALL %s.system.rewrite_data_files(" + + "table => '%s'," + + "options => map(" + + "'delete-file-threshold','1'," + + "'use-starting-sequence-number', 'false'))", + catalogName, tableIdent); + table.refresh(); + + sql( + "INSERT INTO TABLE %s VALUES (5, 'e')", + tableName); // this txn moves the file to the DELETED state + sql("INSERT INTO TABLE %s VALUES (6, 'f')", tableName); // this txn removes the file reference + table.refresh(); + + assertThat(TestHelpers.deleteManifests(table)).as("Should have no delete manifests").hasSize(0); + assertThat(TestHelpers.deleteFiles(table)).as("Should have no delete files").hasSize(0); + + FileSystem localFs = FileSystem.getLocal(new Configuration()); + assertThat(localFs.exists(deleteManifestPath)) + .as("Delete manifest should still exist") + .isTrue(); + assertThat(localFs.exists(deleteFilePath)).as("Delete file should still exist").isTrue(); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + List output = + sql( + "CALL %s.system.expire_snapshots(older_than => TIMESTAMP '%s',table => '%s')", + catalogName, currentTimestamp, tableIdent); + + assertEquals( + "Should deleted 1 data and pos delete file and 4 manifests and lists (one for each txn)", + ImmutableList.of(row(1L, 1L, 0L, 4L, 4L, 0L)), + output); + assertThat(localFs.exists(deleteManifestPath)) + .as("Delete manifest should be removed") + .isFalse(); + assertThat(localFs.exists(deleteFilePath)).as("Delete file should be removed").isFalse(); + } + + @TestTemplate + public void testExpireSnapshotWithStreamResultsEnabled() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.snapshots()).as("Should be 2 snapshots").hasSize(2); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + List output = + sql( + "CALL %s.system.expire_snapshots(" + + "older_than => TIMESTAMP '%s'," + + "table => '%s'," + + "stream_results => true)", + catalogName, currentTimestamp, tableIdent); + assertEquals( + "Procedure output must match", ImmutableList.of(row(0L, 0L, 0L, 0L, 1L, 0L)), output); + } + + @TestTemplate + public void testExpireSnapshotsWithSnapshotId() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.snapshots()).as("Should be 2 snapshots").hasSize(2); 
+ + // Expiring the snapshot specified by snapshot_id should keep only a single snapshot. + long firstSnapshotId = table.currentSnapshot().parentId(); + sql( + "CALL %s.system.expire_snapshots(" + "table => '%s'," + "snapshot_ids => ARRAY(%d))", + catalogName, tableIdent, firstSnapshotId); + + // There should only be one single snapshot left. + table.refresh(); + assertThat(table.snapshots()).as("Should be 1 snapshots").hasSize(1); + assertThat(table.snapshots()) + .as("Snapshot ID should not be present") + .filteredOn(snapshot -> snapshot.snapshotId() == firstSnapshotId) + .hasSize(0); + } + + @TestTemplate + public void testExpireSnapshotShouldFailForCurrentSnapshot() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should be 2 snapshots").hasSize(2); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.expire_snapshots(" + + "table => '%s'," + + "snapshot_ids => ARRAY(%d, %d))", + catalogName, + tableIdent, + table.currentSnapshot().snapshotId(), + table.currentSnapshot().parentId())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Cannot expire"); + } + + @TestTemplate + public void testExpireSnapshotsProcedureWorksWithSqlComments() { + // Ensure that systems such as dbt, that inject comments into the generated SQL files, will + // work with Iceberg-specific DDL + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.snapshots()).as("Should be 2 snapshots").hasSize(2); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + String callStatement = + "/* CALL statement is used to expire snapshots */\n" + + "-- And we have single line comments as well \n" + + "/* And comments that span *multiple* \n" + + " lines */ CALL /* this is the actual CALL */ %s.system.expire_snapshots(" + + " older_than => TIMESTAMP '%s'," + + " table => '%s')"; + List output = sql(callStatement, catalogName, currentTimestamp, tableIdent); + assertEquals( + "Procedure output must match", ImmutableList.of(row(0L, 0L, 0L, 0L, 1L, 0L)), output); + + table.refresh(); + + assertThat(table.snapshots()).as("Should be 1 snapshot remaining").hasSize(1); + } + + @TestTemplate + public void testExpireSnapshotsWithStatisticFiles() throws Exception { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (10, 'abc')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + String statsFileLocation1 = ProcedureUtil.statsFileLocation(table.location()); + StatisticsFile statisticsFile1 = + writeStatsFile( + table.currentSnapshot().snapshotId(), + table.currentSnapshot().sequenceNumber(), + statsFileLocation1, + table.io()); + table.updateStatistics().setStatistics(statisticsFile1.snapshotId(), statisticsFile1).commit(); + + sql("INSERT INTO %s SELECT 20, 'def'", tableName); + table.refresh(); + String statsFileLocation2 = ProcedureUtil.statsFileLocation(table.location()); + StatisticsFile statisticsFile2 = + writeStatsFile( + 
table.currentSnapshot().snapshotId(), + table.currentSnapshot().sequenceNumber(), + statsFileLocation2, + table.io()); + table.updateStatistics().setStatistics(statisticsFile2.snapshotId(), statisticsFile2).commit(); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + List output = + sql( + "CALL %s.system.expire_snapshots(older_than => TIMESTAMP '%s',table => '%s')", + catalogName, currentTimestamp, tableIdent); + assertThat(output.get(0)[5]).as("should be 1 deleted statistics file").isEqualTo(1L); + + table.refresh(); + assertThat(table.statisticsFiles()) + .as( + "Statistics file entry in TableMetadata should be present only for the snapshot %s", + statisticsFile2.snapshotId()) + .extracting(StatisticsFile::snapshotId) + .containsExactly(statisticsFile2.snapshotId()); + + assertThat(new File(statsFileLocation1)) + .as("Statistics file should not exist for snapshot %s", statisticsFile1.snapshotId()) + .doesNotExist(); + + assertThat(new File(statsFileLocation2)) + .as("Statistics file should exist for snapshot %s", statisticsFile2.snapshotId()) + .exists(); + } + + @TestTemplate + public void testExpireSnapshotsWithPartitionStatisticFiles() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (10, 'abc')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + String partitionStatsFileLocation1 = ProcedureUtil.statsFileLocation(table.location()); + PartitionStatisticsFile partitionStatisticsFile1 = + ProcedureUtil.writePartitionStatsFile( + table.currentSnapshot().snapshotId(), partitionStatsFileLocation1, table.io()); + table.updatePartitionStatistics().setPartitionStatistics(partitionStatisticsFile1).commit(); + + sql("INSERT INTO %s SELECT 20, 'def'", tableName); + table.refresh(); + String partitionStatsFileLocation2 = ProcedureUtil.statsFileLocation(table.location()); + PartitionStatisticsFile partitionStatisticsFile2 = + ProcedureUtil.writePartitionStatsFile( + table.currentSnapshot().snapshotId(), partitionStatsFileLocation2, table.io()); + table.updatePartitionStatistics().setPartitionStatistics(partitionStatisticsFile2).commit(); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + List output = + sql( + "CALL %s.system.expire_snapshots(older_than => TIMESTAMP '%s',table => '%s')", + catalogName, currentTimestamp, tableIdent); + assertThat(output.get(0)[5]).as("should be 1 deleted partition statistics file").isEqualTo(1L); + + table.refresh(); + assertThat(table.partitionStatisticsFiles()) + .as( + "partition statistics file entry in TableMetadata should be present only for the snapshot %s", + partitionStatisticsFile2.snapshotId()) + .extracting(PartitionStatisticsFile::snapshotId) + .containsExactly(partitionStatisticsFile2.snapshotId()); + + assertThat(new File(partitionStatsFileLocation1)) + .as( + "partition statistics file should not exist for snapshot %s", + partitionStatisticsFile1.snapshotId()) + .doesNotExist(); + + assertThat(new File(partitionStatsFileLocation2)) + .as( + "partition statistics file should exist for snapshot %s", + partitionStatisticsFile2.snapshotId()) + .exists(); + } + + private static StatisticsFile writeStatsFile( + long snapshotId, long snapshotSequenceNumber, String statsLocation, FileIO fileIO) + throws IOException { + try (PuffinWriter 
puffinWriter = Puffin.write(fileIO.newOutputFile(statsLocation)).build()) { + puffinWriter.add( + new Blob( + "some-blob-type", + ImmutableList.of(1), + snapshotId, + snapshotSequenceNumber, + ByteBuffer.wrap("blob content".getBytes(StandardCharsets.UTF_8)))); + puffinWriter.finish(); + + return new GenericStatisticsFile( + snapshotId, + statsLocation, + puffinWriter.fileSize(), + puffinWriter.footerSize(), + puffinWriter.writtenBlobsMetadata().stream() + .map(GenericBlobMetadata::from) + .collect(ImmutableList.toImmutableList())); + } + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestFastForwardBranchProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestFastForwardBranchProcedure.java new file mode 100644 index 000000000000..7eb334f70aa2 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestFastForwardBranchProcedure.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestFastForwardBranchProcedure extends ExtensionsTestBase { + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testFastForwardBranchUsingPositionalArgs() { + sql("CREATE TABLE %s (id int NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + table.refresh(); + + Snapshot currSnapshot = table.currentSnapshot(); + long sourceRef = currSnapshot.snapshotId(); + + String newBranch = "testBranch"; + String tableNameWithBranch = String.format("%s.branch_%s", tableName, newBranch); + + sql("ALTER TABLE %s CREATE BRANCH %s", tableName, newBranch); + sql("INSERT INTO TABLE %s VALUES(3,'c')", tableNameWithBranch); + + table.refresh(); + long updatedRef = table.snapshot(newBranch).snapshotId(); + + assertEquals( + "Main branch should not have the newly inserted record.", + ImmutableList.of(row(1, "a"), row(2, "b")), + sql("SELECT * FROM %s order by id", tableName)); + + assertEquals( + "Test branch should have the newly inserted record.", + ImmutableList.of(row(1, "a"), row(2, "b"), row(3, "c")), + sql("SELECT * FROM %s order by id", tableNameWithBranch)); + + List output = + sql( + "CALL %s.system.fast_forward('%s', '%s', '%s')", + catalogName, tableIdent, SnapshotRef.MAIN_BRANCH, newBranch); + + assertThat(Arrays.stream(output.get(0)).collect(Collectors.toList()).get(0)) + .isEqualTo(SnapshotRef.MAIN_BRANCH); + + assertThat(Arrays.stream(output.get(0)).collect(Collectors.toList()).get(1)) + .isEqualTo(sourceRef); + + assertThat(Arrays.stream(output.get(0)).collect(Collectors.toList()).get(2)) + .isEqualTo(updatedRef); + + assertEquals( + "Main branch should have the newly inserted record.", + ImmutableList.of(row(1, "a"), row(2, "b"), row(3, "c")), + sql("SELECT * FROM %s order by id", tableName)); + } + + @TestTemplate + public void testFastForwardBranchUsingNamedArgs() { + sql("CREATE TABLE %s (id int NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + String newBranch = "testBranch"; + String tableNameWithBranch = String.format("%s.branch_%s", tableName, newBranch); + + sql("ALTER TABLE %s CREATE BRANCH %s", tableName, newBranch); + sql("INSERT INTO TABLE %s VALUES(3,'c')", tableNameWithBranch); + + assertEquals( + "Main branch should not have the newly inserted record.", + ImmutableList.of(row(1, "a"), row(2, "b")), + sql("SELECT * FROM %s order by id", tableName)); + + assertEquals( + "Test branch should have the newly inserted record.", + ImmutableList.of(row(1, "a"), row(2, "b"), row(3, "c")), + sql("SELECT * FROM %s order by id", tableNameWithBranch)); + + List output = + sql( 
+ "CALL %s.system.fast_forward(table => '%s', branch => '%s', to => '%s')", + catalogName, tableIdent, SnapshotRef.MAIN_BRANCH, newBranch); + + assertEquals( + "Main branch should now have the newly inserted record.", + ImmutableList.of(row(1, "a"), row(2, "b"), row(3, "c")), + sql("SELECT * FROM %s order by id", tableName)); + } + + @TestTemplate + public void testFastForwardWhenTargetIsNotAncestorFails() { + sql("CREATE TABLE %s (id int NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + String newBranch = "testBranch"; + String tableNameWithBranch = String.format("%s.branch_%s", tableName, newBranch); + + sql("ALTER TABLE %s CREATE BRANCH %s", tableName, newBranch); + sql("INSERT INTO TABLE %s VALUES(3,'c')", tableNameWithBranch); + + assertEquals( + "Main branch should not have the newly inserted record.", + ImmutableList.of(row(1, "a"), row(2, "b")), + sql("SELECT * FROM %s order by id", tableName)); + + assertEquals( + "Test branch should have the newly inserted record.", + ImmutableList.of(row(1, "a"), row(2, "b"), row(3, "c")), + sql("SELECT * FROM %s order by id", tableNameWithBranch)); + + // Commit a snapshot on main to deviate the branches + sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.fast_forward(table => '%s', branch => '%s', to => '%s')", + catalogName, tableIdent, SnapshotRef.MAIN_BRANCH, newBranch)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot fast-forward: main is not an ancestor of testBranch"); + } + + @TestTemplate + public void testInvalidFastForwardBranchCases() { + assertThatThrownBy( + () -> + sql( + "CALL %s.system.fast_forward('test_table', branch => 'main', to => 'newBranch')", + catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Named and positional arguments cannot be mixed"); + + assertThatThrownBy( + () -> + sql("CALL %s.custom.fast_forward('test_table', 'main', 'newBranch')", catalogName)) + .isInstanceOf(ParseException.class) + .satisfies( + exception -> { + ParseException parseException = (ParseException) exception; + assertThat(parseException.getErrorClass()).isEqualTo("PARSE_SYNTAX_ERROR"); + assertThat(parseException.getMessageParameters().get("error")).isEqualTo("'CALL'"); + }); + + assertThatThrownBy(() -> sql("CALL %s.system.fast_forward('test_table', 'main')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [to]"); + + assertThatThrownBy( + () -> sql("CALL %s.system.fast_forward('', 'main', 'newBranch')", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot handle an empty identifier for argument table"); + } + + @TestTemplate + public void testFastForwardNonExistingToRefFails() { + sql("CREATE TABLE %s (id int NOT NULL, data string) USING iceberg", tableName); + assertThatThrownBy( + () -> + sql( + "CALL %s.system.fast_forward(table => '%s', branch => '%s', to => '%s')", + catalogName, tableIdent, SnapshotRef.MAIN_BRANCH, "non_existing_branch")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Ref does not exist: non_existing_branch"); + } + + @TestTemplate + public void testFastForwardNonMain() { + sql("CREATE TABLE %s (id int NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + table.refresh(); + + String 
branch1 = "branch1"; + sql("ALTER TABLE %s CREATE BRANCH %s", tableName, branch1); + String tableNameWithBranch1 = String.format("%s.branch_%s", tableName, branch1); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableNameWithBranch1); + table.refresh(); + Snapshot branch1Snapshot = table.snapshot(branch1); + + // Create branch2 from branch1 + String branch2 = "branch2"; + sql( + "ALTER TABLE %s CREATE BRANCH %s AS OF VERSION %d", + tableName, branch2, branch1Snapshot.snapshotId()); + String tableNameWithBranch2 = String.format("%s.branch_%s", tableName, branch2); + sql("INSERT INTO TABLE %s VALUES (3, 'c')", tableNameWithBranch2); + table.refresh(); + Snapshot branch2Snapshot = table.snapshot(branch2); + assertThat( + sql( + "CALL %s.system.fast_forward('%s', '%s', '%s')", + catalogName, tableIdent, branch1, branch2)) + .containsExactly(row(branch1, branch1Snapshot.snapshotId(), branch2Snapshot.snapshotId())); + } + + @TestTemplate + public void testFastForwardNonExistingFromMainCreatesBranch() { + sql("CREATE TABLE %s (id int NOT NULL, data string) USING iceberg", tableName); + String branch1 = "branch1"; + sql("ALTER TABLE %s CREATE BRANCH %s", tableName, branch1); + String branchIdentifier = String.format("%s.branch_%s", tableName, branch1); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", branchIdentifier); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", branchIdentifier); + Table table = validationCatalog.loadTable(tableIdent); + table.refresh(); + Snapshot branch1Snapshot = table.snapshot(branch1); + + assertThat( + sql( + "CALL %s.system.fast_forward('%s', '%s', '%s')", + catalogName, tableIdent, SnapshotRef.MAIN_BRANCH, branch1)) + .containsExactly(row(SnapshotRef.MAIN_BRANCH, null, branch1Snapshot.snapshotId())); + + // Ensure the same behavior for non-main branches + String branch2 = "branch2"; + assertThat( + sql( + "CALL %s.system.fast_forward('%s', '%s', '%s')", + catalogName, tableIdent, branch2, branch1)) + .containsExactly(row(branch2, null, branch1Snapshot.snapshotId())); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java new file mode 100644 index 000000000000..8f00f625cb86 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java @@ -0,0 +1,2983 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE; +import static org.apache.iceberg.TableProperties.MERGE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.MERGE_MODE; +import static org.apache.iceberg.TableProperties.MERGE_MODE_DEFAULT; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.iceberg.TableProperties.SPLIT_SIZE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; +import static org.apache.spark.sql.functions.lit; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.SparkRuntimeException; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.execution.SparkPlan; +import org.apache.spark.sql.internal.SQLConf; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; + +public abstract class TestMerge extends SparkRowLevelOperationsTestBase { + + @BeforeAll + public static void setupSparkConf() { + spark.conf().set("spark.sql.shuffle.partitions", "4"); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS source"); + } + + @TestTemplate + public void testMergeWithAllClauses() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-two\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }\n" + + "{ \"id\": 4, \"dep\": \"emp-id-4\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, 
\"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 5, \"dep\": \"emp-id-5\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 2 THEN " + + " DELETE " + + "WHEN NOT MATCHED THEN " + + " INSERT * " + + "WHEN NOT MATCHED BY SOURCE AND t.id = 3 THEN " + + " UPDATE SET dep = 'invalid' " + + "WHEN NOT MATCHED BY SOURCE AND t.id = 4 THEN " + + " DELETE ", + commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(1, "emp-id-1"), // updated (matched) + // row(2, "emp-id-two) // deleted (matched) + row(3, "invalid"), // updated (not matched by source) + // row(4, "emp-id-4) // deleted (not matched by source) + row(5, "emp-id-5")), // new + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithOneNotMatchedBySourceClause() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }\n" + + "{ \"id\": 4, \"dep\": \"emp-id-4\" }"); + + createOrReplaceView("source", ImmutableList.of(1, 4), Encoders.INT()); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.value " + + "WHEN NOT MATCHED BY SOURCE THEN " + + " DELETE ", + commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(1, "emp-id-1"), // existing + // row(2, "emp-id-2) // deleted (not matched by source) + // row(3, "emp-id-3") // deleted (not matched by source) + row(4, "emp-id-4")), // existing + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeNotMatchedBySourceClausesPartitionedTable() { + createAndInitTable( + "id INT, dep STRING", + "PARTITIONED BY (dep)", + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }\n" + + "{ \"id\": 3, \"dep\": \"support\" }"); + + createOrReplaceView("source", ImmutableList.of(1, 2), Encoders.INT()); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.value AND t.dep = 'hr' " + + "WHEN MATCHED THEN " + + " UPDATE SET dep = 'support' " + + "WHEN NOT MATCHED BY SOURCE THEN " + + " UPDATE SET dep = 'invalid' ", + commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(1, "support"), // updated (matched) + row(2, "support"), // updated (matched) + row(3, "invalid")), // updated (not matched by source) + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithVectorizedReads() { + assumeThat(supportsVectorization()).isTrue(); + + createAndInitTable( + "id INT, value INT, dep STRING", + "PARTITIONED BY (dep)", + "{ \"id\": 1, \"value\": 100, \"dep\": \"hr\" }\n" + + "{ \"id\": 6, \"value\": 600, \"dep\": \"software\" }"); + + createOrReplaceView( + "source", + "id INT, value INT", + "{ \"id\": 2, \"value\": 201 }\n" + + "{ \"id\": 1, \"value\": 101 }\n" + + "{ \"id\": 6, \"value\": 601 }"); + + SparkPlan plan = + executeAndKeepPlan( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET t.value = s.value " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT (id, value, dep) VALUES (s.id, s.value, 'invalid')", + commitTarget()); + + assertAllBatchScansVectorized(plan); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, 101, "hr"), // updated + row(2, 201, "invalid")); 
// new + + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testCoalesceMerge() { + createAndInitTable("id INT, salary INT, dep STRING"); + + String[] records = new String[100]; + for (int index = 0; index < 100; index++) { + records[index] = String.format("{ \"id\": %d, \"salary\": 100, \"dep\": \"hr\" }", index); + } + append(tableName, records); + append(tableName, records); + append(tableName, records); + append(tableName, records); + + // set the open file cost large enough to produce a separate scan task per file + // disable any write distribution + Map tableProps = + ImmutableMap.of( + SPLIT_OPEN_FILE_COST, + String.valueOf(Integer.MAX_VALUE), + MERGE_DISTRIBUTION_MODE, + DistributionMode.NONE.modeName()); + sql("ALTER TABLE %s SET TBLPROPERTIES (%s)", tableName, tablePropsAsString(tableProps)); + + createBranchIfNeeded(); + + spark.range(0, 100).createOrReplaceTempView("source"); + + // enable AQE and set the advisory partition big enough to trigger combining + // set the number of shuffle partitions to 200 to distribute the work across reducers + // disable broadcast joins to make sure the join triggers a shuffle + // set the advisory partition size for shuffles small enough to ensure writes override it + withSQLConf( + ImmutableMap.of( + SQLConf.SHUFFLE_PARTITIONS().key(), + "200", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD().key(), + "-1", + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), + "true", + SQLConf.COALESCE_PARTITIONS_ENABLED().key(), + "true", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES().key(), + "100", + SparkSQLProperties.ADVISORY_PARTITION_SIZE, + String.valueOf(256 * 1024 * 1024)), + () -> { + sql( + "MERGE INTO %s t USING source " + + "ON t.id = source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET salary = -1 ", + commitTarget()); + }); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + + if (mode(table) == COPY_ON_WRITE) { + // CoW MERGE would perform a join on `id` + // every task has data for each of 200 reducers + // AQE detects that all shuffle blocks are small and processes them in 1 task + // otherwise, there would be 200 tasks writing to the table + validateProperty(currentSnapshot, SnapshotSummary.ADDED_FILES_PROP, "1"); + } else { + // MoR MERGE would perform a join on `id` + // every task has data for each of 200 reducers + // AQE detects that all shuffle blocks are small and processes them in 1 task + // otherwise, there would be 200 tasks writing to the table + validateProperty(currentSnapshot, SnapshotSummary.ADDED_DELETE_FILES_PROP, "1"); + } + + assertThat(scalarSql("SELECT COUNT(*) FROM %s WHERE salary = -1", commitTarget())) + .as("Row count must match") + .isEqualTo(400L); + } + + @TestTemplate + public void testSkewMerge() { + createAndInitTable("id INT, salary INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + String[] records = new String[100]; + for (int index = 0; index < 100; index++) { + records[index] = String.format("{ \"id\": %d, \"salary\": 100, \"dep\": \"hr\" }", index); + } + append(tableName, records); + append(tableName, records); + append(tableName, records); + append(tableName, records); + + // set the open file cost large enough to produce a separate scan task per file + // use hash distribution to trigger a shuffle + Map tableProps = + ImmutableMap.of( + SPLIT_OPEN_FILE_COST, + String.valueOf(Integer.MAX_VALUE), + 
MERGE_DISTRIBUTION_MODE, + DistributionMode.HASH.modeName()); + sql("ALTER TABLE %s SET TBLPROPERTIES (%s)", tableName, tablePropsAsString(tableProps)); + + createBranchIfNeeded(); + + spark.range(0, 100).createOrReplaceTempView("source"); + + // enable AQE and set the advisory partition size small enough to trigger a split + // set the number of shuffle partitions to 2 to only have 2 reducers + // set the min coalesce partition size small enough to avoid coalescing + // set the advisory partition size for shuffles big enough to ensure writes override it + withSQLConf( + ImmutableMap.of( + SQLConf.SHUFFLE_PARTITIONS().key(), + "4", + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_SIZE().key(), + "100", + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), + "true", + SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED().key(), + "true", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES().key(), + "256MB", + SparkSQLProperties.ADVISORY_PARTITION_SIZE, + "100"), + () -> { + SparkPlan plan = + executeAndKeepPlan( + "MERGE INTO %s t USING source " + + "ON t.id = source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET salary = -1 ", + commitTarget()); + assertThat(plan.toString()).contains("REBALANCE_PARTITIONS_BY_COL"); + }); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + + if (mode(table) == COPY_ON_WRITE) { + // CoW MERGE would perform a join on `id` and then cluster records by `dep` + // the first shuffle distributes records into 4 shuffle partitions so that rows can be merged + // after existing and new rows are merged, the data is clustered by `dep` + // each task with merged data contains records for the same table partition + // that means there are 4 shuffle blocks, all assigned to the same reducer + // AQE detects that all shuffle blocks are big and processes them in 4 independent tasks + // otherwise, there would be 1 task processing all 4 shuffle blocks + validateProperty(currentSnapshot, SnapshotSummary.ADDED_FILES_PROP, "4"); + } else { + // MoR MERGE would perform a join on `id` and then cluster data based on the partition + // all tasks belong to the same partition and therefore write only 1 shuffle block per task + // that means there are 4 shuffle blocks, all assigned to the same reducer + // AQE detects that all 4 shuffle blocks are big and processes them in 4 separate tasks + // otherwise, there would be 1 task processing 4 shuffle blocks + validateProperty(currentSnapshot, SnapshotSummary.ADDED_DELETE_FILES_PROP, "4"); + } + + assertThat(scalarSql("SELECT COUNT(*) FROM %s WHERE salary = -1", commitTarget())) + .as("Row count must match") + .isEqualTo(400L); + } + + @TestTemplate + public void testMergeConditionSplitIntoTargetPredicateAndJoinCondition() { + createAndInitTable( + "id INT, salary INT, dep STRING, sub_dep STRING", + "PARTITIONED BY (dep, sub_dep)", + "{ \"id\": 1, \"salary\": 100, \"dep\": \"d1\", \"sub_dep\": \"sd1\" }\n" + + "{ \"id\": 6, \"salary\": 600, \"dep\": \"d6\", \"sub_dep\": \"sd6\" }"); + + createOrReplaceView( + "source", + "id INT, salary INT, dep STRING, sub_dep STRING", + "{ \"id\": 1, \"salary\": 101, \"dep\": \"d1\", \"sub_dep\": \"sd1\" }\n" + + "{ \"id\": 2, \"salary\": 200, \"dep\": \"d2\", \"sub_dep\": \"sd2\" }\n" + + "{ \"id\": 3, \"salary\": 300, \"dep\": \"d3\", \"sub_dep\": \"sd3\" }"); + + String query = + String.format( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id AND ((t.dep = 'd1' AND t.sub_dep IN ('sd1', 'sd3')) OR (t.dep = 'd6' AND 
t.sub_dep IN ('sd2', 'sd6'))) " + + "WHEN MATCHED THEN " + + " UPDATE SET salary = s.salary " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + + if (mode(table) == COPY_ON_WRITE) { + checkJoinAndFilterConditions( + query, + "Join [id], [id], FullOuter", + "((dep = 'd1' AND sub_dep IN ('sd1', 'sd3')) OR (dep = 'd6' AND sub_dep IN ('sd2', 'sd6')))"); + } else { + checkJoinAndFilterConditions( + query, + "Join [id], [id], RightOuter", + "((dep = 'd1' AND sub_dep IN ('sd1', 'sd3')) OR (dep = 'd6' AND sub_dep IN ('sd2', 'sd6')))"); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(1, 101, "d1", "sd1"), // updated + row(2, 200, "d2", "sd2"), // new + row(3, 300, "d3", "sd3"), // new + row(6, 600, "d6", "sd6")), // existing + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithStaticPredicatePushDown() { + createAndInitTable("id BIGINT, dep STRING"); + + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + // add a data file to the 'software' partition + append(tableName, "{ \"id\": 1, \"dep\": \"software\" }"); + createBranchIfNeeded(); + + // add a data file to the 'hr' partition + append(commitTarget(), "{ \"id\": 1, \"dep\": \"hr\" }"); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, branch); + String dataFilesCount = snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP); + assertThat(dataFilesCount).as("Must have 2 files before MERGE").isEqualTo("2"); + + createOrReplaceView( + "source", "{ \"id\": 1, \"dep\": \"finance\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + // remove the data file from the 'hr' partition to ensure it is not scanned + withUnavailableFiles( + snapshot.addedDataFiles(table.io()), + () -> { + // disable dynamic pruning and rely only on static predicate pushdown + withSQLConf( + ImmutableMap.of( + SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false", + SQLConf.RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED().key(), "false"), + () -> { + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id AND t.dep IN ('software') AND source.id < 10 " + + "WHEN MATCHED AND source.id = 1 THEN " + + " UPDATE SET dep = source.dep " + + "WHEN NOT MATCHED THEN " + + " INSERT (dep, id) VALUES (source.dep, source.id)", + commitTarget()); + }); + }); + + ImmutableList expectedRows = + ImmutableList.of( + row(1L, "finance"), // updated + row(1L, "hr"), // kept + row(2L, "hardware") // new + ); + assertEquals( + "Output should match", + expectedRows, + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } + + @TestTemplate + public void testMergeIntoEmptyTargetInsertAllNonMatchingRows() { + assumeThat(branch).as("Custom branch does not exist for empty table").isNotEqualTo("test"); + createAndInitTable("id INT, dep STRING"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + tableName); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // new + row(2, "emp-id-2"), // new + row(3, "emp-id-3") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + 
@TestTemplate + public void testMergeIntoEmptyTargetInsertOnlyMatchingRows() { + assumeThat(branch).as("Custom branch does not exist for empty table").isNotEqualTo("test"); + createAndInitTable("id INT, dep STRING"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND (s.id >=2) THEN " + + " INSERT *", + tableName); + + ImmutableList expectedRows = + ImmutableList.of( + row(2, "emp-id-2"), // new + row(3, "emp-id-3") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithOnlyUpdateClause() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-six\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(6, "emp-id-six") // kept + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithOnlyUpdateNullUnmatchedValues() { + createAndInitTable( + "id INT, value INT", "{ \"id\": 1, \"value\": 2 }\n" + "{ \"id\": 6, \"value\": null }"); + + createOrReplaceView("source", "id INT NOT NULL, value INT", "{ \"id\": 1, \"value\": 100 }\n"); + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " UPDATE SET id=123, value=456", + commitTarget()); + + sql("SELECT * FROM %s", commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(6, null), // kept + row(123, 456)); // updated + + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithOnlyUpdateSingleFieldNullUnmatchedValues() { + createAndInitTable( + "id INT, value INT", "{ \"id\": 1, \"value\": 2 }\n" + "{ \"id\": 6, \"value\": null }"); + + createOrReplaceView("source", "id INT NOT NULL, value INT", "{ \"id\": 1, \"value\": 100 }\n"); + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " UPDATE SET id=123", + commitTarget()); + + sql("SELECT * FROM %s", commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(6, null), // kept + row(123, 2)); // updated + + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithOnlyDeleteNullUnmatchedValues() { + createAndInitTable( + "id INT, value INT", "{ \"id\": 1, \"value\": 2 }\n" + "{ \"id\": 6, \"value\": null }"); + + createOrReplaceView("source", "id INT NOT NULL, value INT", "{ \"id\": 1, \"value\": 100 }\n"); + sql( + "MERGE INTO %s t USING source s " + "ON t.id == s.id " + "WHEN MATCHED THEN " + "DELETE", + commitTarget()); + + sql("SELECT * FROM %s", commitTarget()); + + ImmutableList expectedRows = ImmutableList.of(row(6, null)); // kept + + 
assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithOnlyUpdateClauseAndNullValues() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": null, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-six\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id AND t.id < 3 " + + "WHEN MATCHED THEN " + + " UPDATE SET *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(null, "emp-id-one"), // kept + row(1, "emp-id-1"), // updated + row(6, "emp-id-six")); // kept + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithOnlyDeleteClause() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-one") // kept + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithMatchedAndNotMatchedClauses() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithAllCausesWithExplicitColumnSpecification() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET t.id = s.id, t.dep = s.dep " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT (t.id, t.dep) VALUES (s.id, s.dep)", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + 
"Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithSourceCTE() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-two\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-3\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 5, \"dep\": \"emp-id-6\" }"); + + sql( + "WITH cte1 AS (SELECT id + 1 AS id, dep FROM source) " + + "MERGE INTO %s AS t USING cte1 AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 2 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 3 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(2, "emp-id-2"), // updated + row(3, "emp-id-3") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithSourceFromSetOps() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + String derivedSource = + "SELECT * FROM source WHERE id = 2 " + + "UNION ALL " + + "SELECT * FROM source WHERE id = 1 OR id = 6"; + + sql( + "MERGE INTO %s AS t USING (%s) AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget(), derivedSource); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithOneMatchingBranchButMultipleSourceRowsForTargetRow() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"state\": \"on\" }\n" + + "{ \"id\": 1, \"state\": \"off\" }\n" + + "{ \"id\": 10, \"state\": \"on\" }"); + + String errorMsg = + "MERGE statement matched a single row from the target table with multiple rows of the source table."; + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED THEN " + + " INSERT (id, dep) VALUES (s.id, 'unknown')", + commitTarget())) + .isInstanceOf(SparkRuntimeException.class) + .hasMessageContaining(errorMsg); + + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSource() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, 
Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + String errorMsg = + "MERGE statement matched a single row from the target table with multiple rows of the source table."; + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.value = 2 THEN " + + " INSERT (id, dep) VALUES (s.value, null)", + commitTarget())) + .isInstanceOf(SparkRuntimeException.class) + .hasMessageContaining(errorMsg); + + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void + testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceEnabledHashShuffleJoin() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + withSQLConf( + ImmutableMap.of(SQLConf.PREFER_SORTMERGEJOIN().key(), "false"), + () -> { + String errorMsg = + "MERGE statement matched a single row from the target table with multiple rows of the source table."; + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.value = 2 THEN " + + " INSERT (id, dep) VALUES (s.value, null)", + commitTarget())) + .isInstanceOf(SparkRuntimeException.class) + .hasMessageContaining(errorMsg); + }); + + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceNoEqualityCondition() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"emp-id-one\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + withSQLConf( + ImmutableMap.of(SQLConf.PREFER_SORTMERGEJOIN().key(), "false"), + () -> { + String errorMsg = + "MERGE statement matched a single row from the target table with multiple rows of the source table."; + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id > s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.value = 2 THEN " + + " INSERT (id, dep) VALUES (s.value, null)", + commitTarget())) + .isInstanceOf(SparkRuntimeException.class) + .hasMessageContaining(errorMsg); + }); + + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceNoNotMatchedActions() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": 
\"emp-id-6\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + String errorMsg = + "MERGE statement matched a single row from the target table with multiple rows of the source table."; + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE", + commitTarget())) + .isInstanceOf(SparkRuntimeException.class) + .hasMessageContaining(errorMsg); + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void + testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceNoNotMatchedActionsNoEqualityCondition() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"emp-id-one\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + String errorMsg = + "MERGE statement matched a single row from the target table with multiple rows of the source table."; + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id > s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE", + commitTarget())) + .isInstanceOf(SparkRuntimeException.class) + .hasMessageContaining(errorMsg); + + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testMergeWithMultipleUpdatesForTargetRow() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + String errorMsg = + "MERGE statement matched a single row from the target table with multiple rows of the source table."; + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget())) + .isInstanceOf(SparkRuntimeException.class) + .hasMessageContaining(errorMsg); + + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testMergeWithUnconditionalDelete() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + 
+ "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithSingleConditionalDelete() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + String errorMsg = + "MERGE statement matched a single row from the target table with multiple rows of the source table."; + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget())) + .isInstanceOf(SparkRuntimeException.class) + .hasMessageContaining(errorMsg); + + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testMergeWithIdentityTransform() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD identity(dep)", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + createBranchIfNeeded(); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + removeTables(); + } + } + + @TestTemplate + public void testMergeWithDaysTransform() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, ts TIMESTAMP"); + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append( + tableName, + "id INT, ts TIMESTAMP", + "{ \"id\": 1, \"ts\": \"2000-01-01 00:00:00\" }\n" + + "{ \"id\": 6, \"ts\": \"2000-01-06 00:00:00\" }"); + createBranchIfNeeded(); + + createOrReplaceView( + "source", + "id INT, ts TIMESTAMP", + "{ \"id\": 2, \"ts\": \"2001-01-02 00:00:00\" }\n" + + "{ \"id\": 1, \"ts\": \"2001-01-01 00:00:00\" }\n" + + "{ \"id\": 6, \"ts\": \"2001-01-06 00:00:00\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 
6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "2001-01-01 00:00:00"), // updated + row(2, "2001-01-02 00:00:00") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT id, CAST(ts AS STRING) FROM %s ORDER BY id", selectTarget())); + + removeTables(); + } + } + + @TestTemplate + public void testMergeWithBucketTransform() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(2, dep)", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + createBranchIfNeeded(); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + removeTables(); + } + } + + @TestTemplate + public void testMergeWithTruncateTransform() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD truncate(dep, 2)", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + createBranchIfNeeded(); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + removeTables(); + } + } + + @TestTemplate + public void testMergeIntoPartitionedAndOrderedTable() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + sql("ALTER TABLE %s WRITE ORDERED BY (id)", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + createBranchIfNeeded(); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ 
\"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + removeTables(); + } + } + + @TestTemplate + public void testSelfMerge() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": 1, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + sql( + "MERGE INTO %s t USING %s s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET v = 'x' " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + commitTarget(), commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "x"), // updated + row(2, "v2") // kept + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testSelfMergeWithCaching() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": 1, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + sql("CACHE TABLE %s", tableName); + + sql( + "MERGE INTO %s t USING %s s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET v = 'x' " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + commitTarget(), commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "x"), // updated + row(2, "v2") // kept + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", commitTarget())); + } + + @TestTemplate + public void testMergeWithSourceAsSelfSubquery() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": 1, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView("source", Arrays.asList(1, null), Encoders.INT()); + + sql( + "MERGE INTO %s t USING (SELECT id AS value FROM %s r JOIN source ON r.id = source.value) s " + + "ON t.id == s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET v = 'x' " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES ('invalid', -1) ", + commitTarget(), commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "x"), // updated + row(2, "v2") // kept + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public synchronized void testMergeWithSerializableIsolation() throws InterruptedException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + assumeThat(catalogName).isNotEqualToIgnoringCase("testhadoop"); + // if caching is off, the table is eagerly refreshed during runtime filtering + // this can cause a validation exception as concurrent changes would be visible + assumeThat(cachingCatalogEnabled()).isTrue(); + + createAndInitTable("id INT, dep STRING"); + createOrReplaceView("source", Collections.singletonList(1), Encoders.INT()); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, MERGE_ISOLATION_LEVEL, "serializable"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new 
AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // merge thread + Future mergeFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + int currentNumOperations = numOperations; + Awaitility.await() + .pollInterval(10, TimeUnit.MILLISECONDS) + .atMost(5, TimeUnit.SECONDS) + .until(() -> barrier.get() >= currentNumOperations * 2); + + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.value " + + "WHEN MATCHED THEN " + + " UPDATE SET dep = 'x'", + commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + // load the table via the validation catalog to use another table instance + Table table = validationCatalog.loadTable(tableIdent); + + GenericRecord record = GenericRecord.create(table.schema()); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + int currentNumOperations = numOperations; + Awaitility.await() + .pollInterval(10, TimeUnit.MILLISECONDS) + .atMost(5, TimeUnit.SECONDS) + .until(() -> !shouldAppend.get() || barrier.get() >= currentNumOperations * 2); + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + appendFiles.commit(); + } + + barrier.incrementAndGet(); + } + }); + + try { + assertThatThrownBy(mergeFuture::get) + .isInstanceOf(ExecutionException.class) + .cause() + .isInstanceOf(ValidationException.class) + .hasMessageContaining("Found conflicting files that can contain"); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + assertThat(executorService.awaitTermination(2, TimeUnit.MINUTES)).as("Timeout").isTrue(); + } + + @TestTemplate + public synchronized void testMergeWithSnapshotIsolation() + throws InterruptedException, ExecutionException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + assumeThat(catalogName).isNotEqualToIgnoringCase("testhadoop"); + // if caching is off, the table is eagerly refreshed during runtime filtering + // this can cause a validation exception as concurrent changes would be visible + assumeThat(cachingCatalogEnabled()).isTrue(); + + createAndInitTable("id INT, dep STRING"); + createOrReplaceView("source", Collections.singletonList(1), Encoders.INT()); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, MERGE_ISOLATION_LEVEL, "snapshot"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // merge thread + Future mergeFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < 20; numOperations++) { + int currentNumOperations = numOperations; + Awaitility.await() + .pollInterval(10, TimeUnit.MILLISECONDS) + .atMost(5, TimeUnit.SECONDS) + .until(() -> barrier.get() >= currentNumOperations * 2); + + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.value " + + "WHEN MATCHED THEN " + + " 
UPDATE SET dep = 'x'", + commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + // load the table via the validation catalog to use another table instance for inserts + Table table = validationCatalog.loadTable(tableIdent); + + GenericRecord record = GenericRecord.create(table.schema()); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < 20; numOperations++) { + int currentNumOperations = numOperations; + Awaitility.await() + .pollInterval(10, TimeUnit.MILLISECONDS) + .atMost(5, TimeUnit.SECONDS) + .until(() -> !shouldAppend.get() || barrier.get() >= currentNumOperations * 2); + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + } + + barrier.incrementAndGet(); + } + }); + + try { + mergeFuture.get(); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + assertThat(executorService.awaitTermination(2, TimeUnit.MINUTES)).as("Timeout").isTrue(); + } + + @TestTemplate + public void testMergeWithExtraColumnsInSource() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": 1, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + createOrReplaceView( + "source", + "{ \"id\": 1, \"extra_col\": -1, \"v\": \"v1_1\" }\n" + + "{ \"id\": 3, \"extra_col\": -1, \"v\": \"v3\" }\n" + + "{ \"id\": 4, \"extra_col\": -1, \"v\": \"v4\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET v = source.v " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "v1_1"), // new + row(2, "v2"), // kept + row(3, "v3"), // new + row(4, "v4") // new + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithNullsInTargetAndSource() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": null, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView( + "source", "{ \"id\": null, \"v\": \"v1_1\" }\n" + "{ \"id\": 4, \"v\": \"v4\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET v = source.v " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(null, "v1"), // kept + row(null, "v1_1"), // new + row(2, "v2"), // kept + row(4, "v4") // new + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", selectTarget())); + } + + @TestTemplate + public void testMergeWithNullSafeEquals() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": null, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView( + "source", "{ \"id\": null, \"v\": \"v1_1\" }\n" + "{ \"id\": 4, \"v\": \"v4\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.id <=> source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET v = source.v " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList 
expectedRows = + ImmutableList.of( + row(null, "v1_1"), // updated + row(2, "v2"), // kept + row(4, "v4") // new + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", selectTarget())); + } + + @TestTemplate + public void testMergeWithNullCondition() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": null, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView( + "source", "{ \"id\": null, \"v\": \"v1_1\" }\n" + "{ \"id\": 2, \"v\": \"v2_2\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id AND NULL " + + "WHEN MATCHED THEN " + + " UPDATE SET v = source.v " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(null, "v1"), // kept + row(null, "v1_1"), // new + row(2, "v2"), // kept + row(2, "v2_2") // new + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", selectTarget())); + } + + @TestTemplate + public void testMergeWithNullActionConditions() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": 1, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView( + "source", + "{ \"id\": 1, \"v\": \"v1_1\" }\n" + + "{ \"id\": 2, \"v\": \"v2_2\" }\n" + + "{ \"id\": 3, \"v\": \"v3_3\" }"); + + // all conditions are NULL and will never match any rows + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED AND source.id = 1 AND NULL THEN " + + " UPDATE SET v = source.v " + + "WHEN MATCHED AND source.v = 'v1_1' AND NULL THEN " + + " DELETE " + + "WHEN NOT MATCHED AND source.id = 3 AND NULL THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList expectedRows1 = + ImmutableList.of( + row(1, "v1"), // kept + row(2, "v2") // kept + ); + assertEquals( + "Output should match", expectedRows1, sql("SELECT * FROM %s ORDER BY v", selectTarget())); + + // only the update and insert conditions are NULL + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED AND source.id = 1 AND NULL THEN " + + " UPDATE SET v = source.v " + + "WHEN MATCHED AND source.v = 'v1_1' THEN " + + " DELETE " + + "WHEN NOT MATCHED AND source.id = 3 AND NULL THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList expectedRows2 = + ImmutableList.of( + row(2, "v2") // kept + ); + assertEquals( + "Output should match", expectedRows2, sql("SELECT * FROM %s ORDER BY v", selectTarget())); + } + + @TestTemplate + public void testMergeWithMultipleMatchingActions() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": 1, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView( + "source", "{ \"id\": 1, \"v\": \"v1_1\" }\n" + "{ \"id\": 2, \"v\": \"v2_2\" }"); + + // the order of match actions is important in this case + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED AND source.id = 1 THEN " + + " UPDATE SET v = source.v " + + "WHEN MATCHED AND source.v = 'v1_1' THEN " + + " DELETE " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "v1_1"), // updated (also matches the delete cond but update is first) + row(2, "v2") // kept (matches neither the update nor the delete cond) + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", 
selectTarget())); + } + + @TestTemplate + public void testMergeWithMultipleRowGroupsParquet() throws NoSuchTableException { + assumeThat(fileFormat).isEqualTo(FileFormat.PARQUET); + + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 100); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, SPLIT_SIZE, 100); + + createOrReplaceView("source", Collections.singletonList(1), Encoders.INT()); + + List ids = Lists.newArrayListWithCapacity(200); + for (int id = 1; id <= 200; id++) { + ids.add(id); + } + Dataset df = + spark + .createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")); + df.coalesce(1).writeTo(tableName).append(); + createBranchIfNeeded(); + + assertThat(spark.table(commitTarget()).count()).isEqualTo(200); + + // update a record from one of two row groups and copy over the second one + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.value " + + "WHEN MATCHED THEN " + + " UPDATE SET dep = 'x'", + commitTarget()); + + assertThat(spark.table(commitTarget()).count()).isEqualTo(200); + } + + @TestTemplate + public void testMergeInsertOnly() { + createAndInitTable( + "id STRING, v STRING", + "{ \"id\": \"a\", \"v\": \"v1\" }\n" + "{ \"id\": \"b\", \"v\": \"v2\" }"); + createOrReplaceView( + "source", + "{ \"id\": \"a\", \"v\": \"v1_1\" }\n" + + "{ \"id\": \"a\", \"v\": \"v1_2\" }\n" + + "{ \"id\": \"c\", \"v\": \"v3\" }\n" + + "{ \"id\": \"d\", \"v\": \"v4_1\" }\n" + + "{ \"id\": \"d\", \"v\": \"v4_2\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row("a", "v1"), // kept + row("b", "v2"), // kept + row("c", "v3"), // new + row("d", "v4_1"), // new + row("d", "v4_2") // new + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeInsertOnlyWithCondition() { + createAndInitTable("id INTEGER, v INTEGER", "{ \"id\": 1, \"v\": 1 }"); + createOrReplaceView( + "source", + "{ \"id\": 1, \"v\": 11, \"is_new\": true }\n" + + "{ \"id\": 2, \"v\": 21, \"is_new\": true }\n" + + "{ \"id\": 2, \"v\": 22, \"is_new\": false }"); + + // validate assignments are reordered to match the table attrs + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND is_new = TRUE THEN " + + " INSERT (v, id) VALUES (s.v + 100, s.id)", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, 1), // kept + row(2, 121) // new + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeAlignsUpdateAndInsertActions() { + createAndInitTable("id INT, a INT, b STRING", "{ \"id\": 1, \"a\": 2, \"b\": \"str\" }"); + createOrReplaceView( + "source", + "{ \"id\": 1, \"c1\": -2, \"c2\": \"new_str_1\" }\n" + + "{ \"id\": 2, \"c1\": -20, \"c2\": \"new_str_2\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET b = c2, a = c1, t.id = source.id " + + "WHEN NOT MATCHED THEN " + + " INSERT (b, a, id) VALUES (c2, c1, id)", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, -2, "new_str_1"), row(2, -20, "new_str_2")), 
+ sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeMixedCaseAlignsUpdateAndInsertActions() { + createAndInitTable("id INT, a INT, b STRING", "{ \"id\": 1, \"a\": 2, \"b\": \"str\" }"); + createOrReplaceView( + "source", + "{ \"id\": 1, \"c1\": -2, \"c2\": \"new_str_1\" }\n" + + "{ \"id\": 2, \"c1\": -20, \"c2\": \"new_str_2\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.iD == source.Id " + + "WHEN MATCHED THEN " + + " UPDATE SET B = c2, A = c1, t.Id = source.ID " + + "WHEN NOT MATCHED THEN " + + " INSERT (b, A, iD) VALUES (c2, c1, id)", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, -2, "new_str_1"), row(2, -20, "new_str_2")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, -2, "new_str_1")), + sql("SELECT * FROM %s WHERE id = 1 ORDER BY id", selectTarget())); + assertEquals( + "Output should match", + ImmutableList.of(row(2, -20, "new_str_2")), + sql("SELECT * FROM %s WHERE b = 'new_str_2'ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeUpdatesNestedStructFields() { + createAndInitTable( + "id INT, s STRUCT,m:MAP>>", + "{ \"id\": 1, \"s\": { \"c1\": 2, \"c2\": { \"a\": [1,2], \"m\": { \"a\": \"b\"} } } } }"); + createOrReplaceView("source", "{ \"id\": 1, \"c1\": -2 }"); + + // update primitive, array, map columns inside a struct + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.c1 = source.c1, t.s.c2.a = array(-1, -2), t.s.c2.m = map('k', 'v')", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(-2, row(ImmutableList.of(-1, -2), ImmutableMap.of("k", "v"))))), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // set primitive, array, map columns to NULL (proper casts should be in place) + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.c1 = NULL, t.s.c2 = NULL", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(null, null))), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // update all fields in a struct + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s = named_struct('c1', 100, 'c2', named_struct('a', array(1), 'm', map('x', 'y')))", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(100, row(ImmutableList.of(1), ImmutableMap.of("x", "y"))))), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithInferredCasts() { + createAndInitTable("id INT, s STRING", "{ \"id\": 1, \"s\": \"value\" }"); + createOrReplaceView("source", "{ \"id\": 1, \"c1\": -2}"); + + // -2 in source should be casted to "-2" in target + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s = source.c1", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, "-2")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeModifiesNullStruct() { + createAndInitTable("id INT, s STRUCT", "{ \"id\": 1, \"s\": null }"); + createOrReplaceView("source", "{ \"id\": 1, \"n1\": -10 }"); + + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n1 = 
s.n1", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(-10, null))), + sql("SELECT * FROM %s", selectTarget())); + } + + @TestTemplate + public void testMergeRefreshesRelationCache() { + createAndInitTable("id INT, name STRING", "{ \"id\": 1, \"name\": \"n1\" }"); + createOrReplaceView("source", "{ \"id\": 1, \"name\": \"n2\" }"); + + Dataset query = spark.sql("SELECT name FROM " + commitTarget()); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals( + "View should have correct data", ImmutableList.of(row("n1")), sql("SELECT * FROM tmp")); + + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.name = s.name", + commitTarget()); + + assertEquals( + "View should have correct data", ImmutableList.of(row("n2")), sql("SELECT * FROM tmp")); + + spark.sql("UNCACHE TABLE tmp"); + } + + @TestTemplate + public void testMergeWithMultipleNotMatchedActions() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 0, \"dep\": \"emp-id-0\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND s.id = 1 THEN " + + " INSERT (dep, id) VALUES (s.dep, -1)" + + "WHEN NOT MATCHED THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(-1, "emp-id-1"), // new + row(0, "emp-id-0"), // kept + row(2, "emp-id-2"), // new + row(3, "emp-id-3") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithMultipleConditionalNotMatchedActions() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 0, \"dep\": \"emp-id-0\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND s.id = 1 THEN " + + " INSERT (dep, id) VALUES (s.dep, -1)" + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(-1, "emp-id-1"), // new + row(0, "emp-id-0"), // kept + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeResolvesColumnsByName() { + createAndInitTable( + "id INT, badge INT, dep STRING", + "{ \"id\": 1, \"badge\": 1000, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"badge\": 6000, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "badge INT, id INT, dep STRING", + "{ \"badge\": 1001, \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"badge\": 6006, \"id\": 6, \"dep\": \"emp-id-6\" }\n" + + "{ \"badge\": 7007, \"id\": 7, \"dep\": \"emp-id-7\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " UPDATE SET * " + + "WHEN NOT MATCHED THEN " + + " INSERT * ", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, 1001, "emp-id-1"), // updated + row(6, 6006, "emp-id-6"), // updated + row(7, 7007, "emp-id-7") // new + ); + assertEquals( + 
"Should have expected rows", + expectedRows, + sql("SELECT id, badge, dep FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeShouldResolveWhenThereAreNoUnresolvedExpressionsOrColumns() { + // ensures that MERGE INTO will resolve into the correct action even if no columns + // or otherwise unresolved expressions exist in the query (testing SPARK-34962) + createAndInitTable("id INT, dep STRING"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON 1 != 1 " + + "WHEN MATCHED THEN " + + " UPDATE SET * " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + tableName); + createBranchIfNeeded(); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // new + row(2, "emp-id-2"), // new + row(3, "emp-id-3") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithTableWithNonNullableColumn() { + createAndInitTable( + "id INT NOT NULL, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT NOT NULL, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2")); // new + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithNonExistingColumns() { + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.invalid_col = s.c2", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "A column, variable, or function parameter with name `t`.`invalid_col` cannot be resolved"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n2.invalid_col = s.c2", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("No such struct field `invalid_col`"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n2.dn1 = s.c2 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id, invalid_col) VALUES (s.c1, null)", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "A column, variable, or function parameter with name `invalid_col` cannot be resolved"); + } + + @TestTemplate + public void testMergeWithInvalidColumnsInInsert() { + createAndInitTable( + "id INT, c STRUCT> NOT NULL", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + 
createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n2.dn1 = s.c2 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id, c.n2) VALUES (s.c1, null)", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("INSERT assignment keys cannot be nested fields"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n2.dn1 = s.c2 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id, id) VALUES (s.c1, null)", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Multiple assignments for 'id'"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id) VALUES (s.c1)", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("No assignment for 'c'"); + } + + @TestTemplate + public void testMergeWithMissingOptionalColumnsInInsert() { + createAndInitTable("id INT, value LONG", "{ \"id\": 1, \"value\": 100}"); + createOrReplaceView("source", "{ \"c1\": 2, \"c2\": 200 }"); + + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id) VALUES (s.c1)", + commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(1, 100L), // existing + row(2, null)), // new + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMergeWithInvalidUpdates() { + createAndInitTable( + "id INT, a ARRAY>, m MAP", + "{ \"id\": 1, \"a\": [ { \"c1\": 2, \"c2\": 3 } ], \"m\": { \"k\": \"v\"} }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.a.c1 = s.c2", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Updating nested fields is only supported for StructType"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.m.key = 'new_key'", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Updating nested fields is only supported for StructType"); + } + + @TestTemplate + public void testMergeWithConflictingUpdates() { + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.id = 1, t.c.n1 = 2, t.id = 2", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Multiple assignments for 'id"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n1 = 1, t.id = 2, t.c.n1 = 2", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Multiple assignments for 'c.n1'"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET c.n1 = 1, c = named_struct('n1', 1, 'n2', named_struct('dn1', 1, 'dn2', 2))", + commitTarget())) 
+ .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Conflicting assignments for 'c'"); + } + + @TestTemplate + public void testMergeWithInvalidAssignmentsAnsi() { + createAndInitTable( + "id INT NOT NULL, s STRUCT> NOT NULL", + "{ \"id\": 1, \"s\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView( + "source", + "c1 INT, c2 STRUCT NOT NULL, c3 STRING NOT NULL, c4 STRUCT", + "{ \"c1\": 1, \"c2\": { \"n1\" : 1 }, \"c3\" : 'str', \"c4\": { \"dn3\": 1, \"dn1\": 2 } }"); + + withSQLConf( + ImmutableMap.of(SQLConf.STORE_ASSIGNMENT_POLICY().key(), "ansi"), + () -> { + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n1 = NULL", + commitTarget())) + .isInstanceOf(SparkRuntimeException.class) + .hasMessageContaining("NULL value appeared in non-nullable field"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s = s.c2", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot find data for the output column `s`.`n2`"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n1 = s.c3", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot safely cast `s`.`n1` \"STRING\" to \"INT\"."); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n2 = s.c4", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot find data for the output column `s`.`n2`.`dn2`"); + }); + } + + @TestTemplate + public void testMergeWithInvalidAssignmentsStrict() { + createAndInitTable( + "id INT NOT NULL, s STRUCT> NOT NULL", + "{ \"id\": 1, \"s\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView( + "source", + "c1 INT, c2 STRUCT NOT NULL, c3 STRING NOT NULL, c4 STRUCT", + "{ \"c1\": 1, \"c2\": { \"n1\" : 1 }, \"c3\" : 'str', \"c4\": { \"dn3\": 1, \"dn1\": 2 } }"); + + withSQLConf( + ImmutableMap.of(SQLConf.STORE_ASSIGNMENT_POLICY().key(), "strict"), + () -> { + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.id = NULL", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot safely cast `id` \"VOID\" to \"INT\""); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n1 = NULL", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot safely cast `s`.`n1` \"VOID\" to \"INT\""); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s = s.c2", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot find data for the output column `s`.`n2`"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n1 = s.c3", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot safely cast `s`.`n1` \"STRING\" to \"INT\"."); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN 
MATCHED THEN " + + " UPDATE SET t.s.n2 = s.c4", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot find data for the output column `s`.`n2`.`dn2`"); + }); + } + + @TestTemplate + public void testMergeWithNonDeterministicConditions() { + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 AND rand() > t.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n1 = -1", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "MERGE operation contains unsupported SEARCH condition. Non-deterministic expressions are not allowed"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND rand() > t.id THEN " + + " UPDATE SET t.c.n1 = -1", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "MERGE operation contains unsupported UPDATE condition. Non-deterministic expressions are not allowed"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND rand() > t.id THEN " + + " DELETE", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "MERGE operation contains unsupported DELETE condition. Non-deterministic expressions are not allowed"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN NOT MATCHED AND rand() > c1 THEN " + + " INSERT (id, c) VALUES (1, null)", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "MERGE operation contains unsupported INSERT condition. Non-deterministic expressions are not allowed"); + } + + @TestTemplate + public void testMergeWithAggregateExpressions() { + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 AND max(t.id) == 1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n1 = -1", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "MERGE operation contains unsupported SEARCH condition. Aggregates are not allowed"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND sum(t.id) < 1 THEN " + + " UPDATE SET t.c.n1 = -1", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "MERGE operation contains unsupported UPDATE condition. Aggregates are not allowed"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND sum(t.id) THEN " + + " DELETE", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "MERGE operation contains unsupported DELETE condition. Aggregates are not allowed"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN NOT MATCHED AND sum(c1) < 1 THEN " + + " INSERT (id, c) VALUES (1, null)", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "MERGE operation contains unsupported INSERT condition. 
Aggregates are not allowed"); + } + + @TestTemplate + public void testMergeWithSubqueriesInConditions() { + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 AND t.id < (SELECT max(c2) FROM source) " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n1 = s.c2", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "MERGE operation contains unsupported SEARCH condition. Subqueries are not allowed"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND t.id < (SELECT max(c2) FROM source) THEN " + + " UPDATE SET t.c.n1 = s.c2", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "MERGE operation contains unsupported UPDATE condition. Subqueries are not allowed"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND t.id NOT IN (SELECT c2 FROM source) THEN " + + " DELETE", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "MERGE operation contains unsupported DELETE condition. Subqueries are not allowed"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN NOT MATCHED AND s.c1 IN (SELECT c2 FROM source) THEN " + + " INSERT (id, c) VALUES (1, null)", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "MERGE operation contains unsupported INSERT condition. Subqueries are not allowed"); + } + + @TestTemplate + public void testMergeWithTargetColumnsInInsertConditions() { + createAndInitTable("id INT, c2 INT", "{ \"id\": 1, \"c2\": 2 }"); + createOrReplaceView("source", "{ \"id\": 1, \"value\": 11 }"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND c2 = 1 THEN " + + " INSERT (id, c2) VALUES (s.id, null)", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "A column, variable, or function parameter with name `c2` cannot be resolved"); + } + + @TestTemplate + public void testMergeWithNonIcebergTargetTableNotSupported() { + createOrReplaceView("target", "{ \"c1\": -100, \"c2\": -200 }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + assertThatThrownBy( + () -> + sql( + "MERGE INTO target t USING source s " + + "ON t.c1 == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET *")) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessage("MERGE INTO TABLE is not supported temporarily."); + } + + /** + * Tests a merge where both the source and target are evaluated to be partitioned by + * SingePartition at planning time but DynamicFileFilterExec will return an empty target. 
+ */ + @TestTemplate + public void testMergeSinglePartitionPartitioning() { + // This table will only have a single file and a single partition + createAndInitTable("id INT", "{\"id\": -1}"); + + // Coalesce forces our source into a SinglePartition distribution + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + + sql( + "MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of(row(-1), row(0), row(1), row(2), row(3), row(4)); + + List result = sql("SELECT * FROM %s ORDER BY id", selectTarget()); + assertEquals("Should correctly add the non-matching rows", expectedRows, result); + } + + @TestTemplate + public void testMergeEmptyTable() { + assumeThat(branch).as("Custom branch does not exist for empty table").isNotEqualTo("test"); + // This table will only have a single file and a single partition + createAndInitTable("id INT", null); + + // Coalesce forces our source into a SinglePartition distribution + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + + sql( + "MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + commitTarget()); + + ImmutableList expectedRows = ImmutableList.of(row(0), row(1), row(2), row(3), row(4)); + + List result = sql("SELECT * FROM %s ORDER BY id", selectTarget()); + assertEquals("Should correctly add the non-matching rows", expectedRows, result); + } + + @TestTemplate + public void testMergeNonExistingBranch() { + assumeThat(branch).as("Test only applicable to custom branch").isEqualTo("test"); + createAndInitTable("id INT", null); + + // Coalesce forces our source into a SinglePartition distribution + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot use branch (does not exist): test"); + } + + @TestTemplate + public void testMergeToWapBranch() { + assumeThat(branch).as("WAP branch only works for table identifier without branch").isNull(); + + createAndInitTable("id INT", "{\"id\": -1}"); + ImmutableList originalRows = ImmutableList.of(row(-1)); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + ImmutableList expectedRows = + ImmutableList.of(row(-1), row(0), row(1), row(2), row(3), row(4)); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql( + "MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + tableName); + assertEquals( + "Should have expected rows when reading table", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", tableName)); + assertEquals( + "Should have expected rows when reading WAP branch", + expectedRows, + sql("SELECT * FROM %s.branch_wap ORDER BY id", tableName)); + assertEquals( + "Should not modify main branch", + originalRows, + sql("SELECT * FROM %s.branch_main ORDER BY id", tableName)); + }); + + spark.range(3, 6).coalesce(1).createOrReplaceTempView("source2"); + ImmutableList expectedRows2 = + ImmutableList.of(row(-1), row(0), row(1), row(2), row(5)); + withSQLConf( + 
ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql( + "MERGE INTO %s t USING source2 s ON t.id = s.id " + + "WHEN MATCHED THEN DELETE " + + "WHEN NOT MATCHED THEN INSERT *", + tableName); + assertEquals( + "Should have expected rows when reading table with multiple writes", + expectedRows2, + sql("SELECT * FROM %s ORDER BY id", tableName)); + assertEquals( + "Should have expected rows when reading WAP branch with multiple writes", + expectedRows2, + sql("SELECT * FROM %s.branch_wap ORDER BY id", tableName)); + assertEquals( + "Should not modify main branch with multiple writes", + originalRows, + sql("SELECT * FROM %s.branch_main ORDER BY id", tableName)); + }); + } + + @TestTemplate + public void testMergeToWapBranchWithTableBranchIdentifier() { + assumeThat(branch).as("Test must have branch name part in table identifier").isNotNull(); + + createAndInitTable("id INT", "{\"id\": -1}"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + ImmutableList expectedRows = + ImmutableList.of(row(-1), row(0), row(1), row(2), row(3), row(4)); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> + assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage( + String.format( + "Cannot write to both branch and WAP branch, but got branch [%s] and WAP branch [wap]", + branch))); + } + + private void checkJoinAndFilterConditions(String query, String join, String icebergFilters) { + // disable runtime filtering for easier validation + withSQLConf( + ImmutableMap.of( + SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false", + SQLConf.RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED().key(), "false"), + () -> { + SparkPlan sparkPlan = executeAndKeepPlan(() -> sql(query)); + String planAsString = sparkPlan.toString().replaceAll("#(\\d+L?)", ""); + + assertThat(planAsString).as("Join should match").contains(join + "\n"); + + assertThat(planAsString) + .as("Pushed filters must match") + .contains("[filters=" + icebergFilters + ","); + }); + } + + private RowLevelOperationMode mode(Table table) { + String modeName = table.properties().getOrDefault(MERGE_MODE, MERGE_MODE_DEFAULT); + return RowLevelOperationMode.fromName(modeName); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java new file mode 100644 index 000000000000..60941b8d5560 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.deletes.DeleteGranularity; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.spark.source.TestSparkCatalog; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestMergeOnReadDelete extends TestDelete { + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.FORMAT_VERSION, + "2", + TableProperties.DELETE_MODE, + RowLevelOperationMode.MERGE_ON_READ.modeName()); + } + + @BeforeEach + public void clearTestSparkCatalogCache() { + TestSparkCatalog.clearTables(); + } + + @TestTemplate + public void testDeleteWithExecutorCacheLocality() throws NoSuchTableException { + createAndInitPartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(2, "hr")); + append(tableName, new Employee(3, "hr"), new Employee(4, "hr")); + append(tableName, new Employee(1, "hardware"), new Employee(2, "hardware")); + append(tableName, new Employee(3, "hardware"), new Employee(4, "hardware")); + + createBranchIfNeeded(); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED, "true"), + () -> { + sql("DELETE FROM %s WHERE id = 1", commitTarget()); + sql("DELETE FROM %s WHERE id = 3", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(2, "hr"), row(4, "hardware"), row(4, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC, dep ASC", selectTarget())); + }); + } + + @TestTemplate + public void testDeleteFileGranularity() throws NoSuchTableException { + checkDeleteFileGranularity(DeleteGranularity.FILE); + } + + @TestTemplate + public void testDeletePartitionGranularity() throws NoSuchTableException { + 
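+ // same scenario as the file-granularity case above, but with partition-scoped delete files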
checkDeleteFileGranularity(DeleteGranularity.PARTITION); + } + + @TestTemplate + public void testPositionDeletesAreMaintainedDuringDelete() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id int, data string) USING iceberg PARTITIONED BY (id) TBLPROPERTIES" + + "('%s'='%s', '%s'='%s', '%s'='%s')", + tableName, + TableProperties.FORMAT_VERSION, + 2, + TableProperties.DELETE_MODE, + "merge-on-read", + TableProperties.DELETE_GRANULARITY, + "file"); + createBranchIfNeeded(); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(1, "b"), + new SimpleRecord(1, "c"), + new SimpleRecord(2, "d"), + new SimpleRecord(2, "e")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(commitTarget()) + .append(); + + sql("DELETE FROM %s WHERE id = 1 and data='a'", commitTarget()); + sql("DELETE FROM %s WHERE id = 2 and data='d'", commitTarget()); + sql("DELETE FROM %s WHERE id = 1 and data='c'", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot latest = SnapshotUtil.latestSnapshot(table, branch); + assertThat(latest.removedDeleteFiles(table.io())).hasSize(1); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "b"), row(2, "e")), + sql("SELECT * FROM %s ORDER BY id ASC", selectTarget())); + } + + @TestTemplate + public void testUnpartitionedPositionDeletesAreMaintainedDuringDelete() + throws NoSuchTableException { + sql( + "CREATE TABLE %s (id int, data string) USING iceberg TBLPROPERTIES" + + "('%s'='%s', '%s'='%s', '%s'='%s')", + tableName, + TableProperties.FORMAT_VERSION, + 2, + TableProperties.DELETE_MODE, + "merge-on-read", + TableProperties.DELETE_GRANULARITY, + "file"); + createBranchIfNeeded(); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(1, "b"), + new SimpleRecord(1, "c"), + new SimpleRecord(2, "d"), + new SimpleRecord(2, "e")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(commitTarget()) + .append(); + + sql("DELETE FROM %s WHERE id = 1 and data='a'", commitTarget()); + sql("DELETE FROM %s WHERE id = 2 and data='d'", commitTarget()); + sql("DELETE FROM %s WHERE id = 1 and data='c'", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot latest = SnapshotUtil.latestSnapshot(table, branch); + assertThat(latest.removedDeleteFiles(table.io())).hasSize(1); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "b"), row(2, "e")), + sql("SELECT * FROM %s ORDER BY id ASC", selectTarget())); + } + + private void checkDeleteFileGranularity(DeleteGranularity deleteGranularity) + throws NoSuchTableException { + createAndInitPartitionedTable(); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')", + tableName, TableProperties.DELETE_GRANULARITY, deleteGranularity); + + append(tableName, new Employee(1, "hr"), new Employee(2, "hr")); + append(tableName, new Employee(3, "hr"), new Employee(4, "hr")); + append(tableName, new Employee(1, "hardware"), new Employee(2, "hardware")); + append(tableName, new Employee(3, "hardware"), new Employee(4, "hardware")); + + createBranchIfNeeded(); + + sql("DELETE FROM %s WHERE id = 1 OR id = 3", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).hasSize(5); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + String expectedDeleteFilesCount = deleteGranularity == DeleteGranularity.FILE ? 
"4" : "2"; + validateMergeOnRead(currentSnapshot, "2", expectedDeleteFilesCount, null); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(2, "hr"), row(4, "hardware"), row(4, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC, dep ASC", selectTarget())); + } + + @TestTemplate + public void testCommitUnknownException() { + createAndInitTable("id INT, dep STRING, category STRING"); + + // write unpartitioned files + append(tableName, "{ \"id\": 1, \"dep\": \"hr\", \"category\": \"c1\"}"); + createBranchIfNeeded(); + append( + commitTarget(), + "{ \"id\": 2, \"dep\": \"hr\", \"category\": \"c1\" }\n" + + "{ \"id\": 3, \"dep\": \"hr\", \"category\": \"c1\" }"); + + Table table = validationCatalog.loadTable(tableIdent); + + RowDelta newRowDelta = table.newRowDelta(); + if (branch != null) { + newRowDelta.toBranch(branch); + } + + RowDelta spyNewRowDelta = spy(newRowDelta); + doAnswer( + invocation -> { + newRowDelta.commit(); + throw new CommitStateUnknownException(new RuntimeException("Datacenter on Fire")); + }) + .when(spyNewRowDelta) + .commit(); + + Table spyTable = spy(table); + when(spyTable.newRowDelta()).thenReturn(spyNewRowDelta); + SparkTable sparkTable = + branch == null ? new SparkTable(spyTable, false) : new SparkTable(spyTable, branch, false); + + ImmutableMap config = + ImmutableMap.of( + "type", "hive", + "default-namespace", "default"); + spark + .conf() + .set("spark.sql.catalog.dummy_catalog", "org.apache.iceberg.spark.source.TestSparkCatalog"); + config.forEach( + (key, value) -> spark.conf().set("spark.sql.catalog.dummy_catalog." + key, value)); + Identifier ident = Identifier.of(new String[] {"default"}, "table"); + TestSparkCatalog.setTable(ident, sparkTable); + + // Although an exception is thrown here, write and commit have succeeded + assertThatThrownBy(() -> sql("DELETE FROM %s WHERE id = 2", "dummy_catalog.default.table")) + .isInstanceOf(CommitStateUnknownException.class) + .hasMessageStartingWith("Datacenter on Fire"); + + // Since write and commit succeeded, the rows should be readable + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr", "c1"), row(3, "hr", "c1")), + sql("SELECT * FROM %s ORDER BY id", "dummy_catalog.default.table")); + } + + @TestTemplate + public void testAggregatePushDownInMergeOnReadDelete() { + createAndInitTable("id LONG, data INT"); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ", + tableName); + createBranchIfNeeded(); + + sql("DELETE FROM %s WHERE data = 1111", commitTarget()); + String select = "SELECT max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, selectTarget()); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)") + || explainString.contains("min(data)") + || explainString.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + assertThat(explainContainsPushDownAggregates) + .as("min/max/count not pushed down for deleted") + .isFalse(); + + List actual = sql(select, selectTarget()); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6666, 2222, 5L}); + assertEquals("min/max/count push down", expected, actual); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadMerge.java 
b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadMerge.java new file mode 100644 index 000000000000..71ca3421f28d --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadMerge.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Map; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.deletes.DeleteGranularity; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.Encoders; +import org.junit.jupiter.api.TestTemplate; + +public class TestMergeOnReadMerge extends TestMerge { + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.FORMAT_VERSION, + "2", + TableProperties.MERGE_MODE, + RowLevelOperationMode.MERGE_ON_READ.modeName()); + } + + @TestTemplate + public void testMergeDeleteFileGranularity() { + checkMergeDeleteGranularity(DeleteGranularity.FILE); + } + + @TestTemplate + public void testMergeDeletePartitionGranularity() { + checkMergeDeleteGranularity(DeleteGranularity.PARTITION); + } + + private void checkMergeDeleteGranularity(DeleteGranularity deleteGranularity) { + createAndInitTable("id INT, dep STRING", "PARTITIONED BY (dep)", null /* empty */); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')", + tableName, TableProperties.DELETE_GRANULARITY, deleteGranularity); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 3, \"dep\": \"hr\" }\n" + "{ \"id\": 4, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 1, \"dep\": \"it\" }\n" + "{ \"id\": 2, \"dep\": \"it\" }"); + append(tableName, "{ \"id\": 3, \"dep\": \"it\" }\n" + "{ \"id\": 4, \"dep\": \"it\" }"); + + createBranchIfNeeded(); + + createOrReplaceView("source", ImmutableList.of(1, 3, 5), Encoders.INT()); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.value " + + "WHEN MATCHED THEN " + + " DELETE " + + "WHEN NOT MATCHED THEN " + + " INSERT (id, dep) VALUES (-1, 'other')", + commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).hasSize(5); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + String expectedDeleteFilesCount = deleteGranularity == DeleteGranularity.FILE ? 
"4" : "2"; + validateMergeOnRead(currentSnapshot, "3", expectedDeleteFilesCount, "1"); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "other"), row(2, "hr"), row(2, "it"), row(4, "hr"), row(4, "it")), + sql("SELECT * FROM %s ORDER BY id ASC, dep ASC", selectTarget())); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadUpdate.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadUpdate.java new file mode 100644 index 000000000000..e9cc9d8541ad --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadUpdate.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.util.Map; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.deletes.DeleteGranularity; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.util.SnapshotUtil; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestMergeOnReadUpdate extends TestUpdate { + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.FORMAT_VERSION, + "2", + TableProperties.UPDATE_MODE, + RowLevelOperationMode.MERGE_ON_READ.modeName()); + } + + @TestTemplate + public void testUpdateFileGranularity() { + checkUpdateFileGranularity(DeleteGranularity.FILE); + } + + @TestTemplate + public void testUpdatePartitionGranularity() { + checkUpdateFileGranularity(DeleteGranularity.PARTITION); + } + + @TestTemplate + public void testUpdateFileGranularityMergesDeleteFiles() { + // Range distribution will produce partition scoped deletes which will not be cleaned up + assumeThat(distributionMode).isNotEqualToIgnoringCase("range"); + + checkUpdateFileGranularity(DeleteGranularity.FILE); + sql("UPDATE %s SET id = id + 1 WHERE id = 4", commitTarget()); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + String expectedDeleteFilesCount = "2"; + validateMergeOnRead(currentSnapshot, "2", expectedDeleteFilesCount, "2"); + + assertThat(currentSnapshot.removedDeleteFiles(table.io())).hasSize(2); 
+ assertEquals( + "Should have expected rows", + ImmutableList.of( + row(0, "hr"), + row(2, "hr"), + row(2, "hr"), + row(5, "hr"), + row(0, "it"), + row(2, "it"), + row(2, "it"), + row(5, "it")), + sql("SELECT * FROM %s ORDER BY dep ASC, id ASC", selectTarget())); + } + + @TestTemplate + public void testUpdateUnpartitionedFileGranularityMergesDeleteFiles() { + // Range distribution will produce partition scoped deletes which will not be cleaned up + assumeThat(distributionMode).isNotEqualToIgnoringCase("range"); + initTable("", DeleteGranularity.FILE); + + sql("UPDATE %s SET id = id - 1 WHERE id = 1 OR id = 3", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).hasSize(5); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + String expectedDeleteFilesCount = "4"; + validateMergeOnRead(currentSnapshot, "1", expectedDeleteFilesCount, "1"); + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(0, "hr"), + row(2, "hr"), + row(2, "hr"), + row(4, "hr"), + row(0, "it"), + row(2, "it"), + row(2, "it"), + row(4, "it")), + sql("SELECT * FROM %s ORDER BY dep ASC, id ASC", selectTarget())); + + sql("UPDATE %s SET id = id + 1 WHERE id = 4", commitTarget()); + table.refresh(); + currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + expectedDeleteFilesCount = "2"; + + validateMergeOnRead(currentSnapshot, "1", expectedDeleteFilesCount, "1"); + assertThat(currentSnapshot.removedDeleteFiles(table.io())).hasSize(2); + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(0, "hr"), + row(2, "hr"), + row(2, "hr"), + row(5, "hr"), + row(0, "it"), + row(2, "it"), + row(2, "it"), + row(5, "it")), + sql("SELECT * FROM %s ORDER BY dep ASC, id ASC", selectTarget())); + } + + private void checkUpdateFileGranularity(DeleteGranularity deleteGranularity) { + initTable("PARTITIONED BY (dep)", deleteGranularity); + + sql("UPDATE %s SET id = id - 1 WHERE id = 1 OR id = 3", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).hasSize(5); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + String expectedDeleteFilesCount = deleteGranularity == DeleteGranularity.FILE ? 
"4" : "2"; + validateMergeOnRead(currentSnapshot, "2", expectedDeleteFilesCount, "2"); + + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(0, "hr"), + row(2, "hr"), + row(2, "hr"), + row(4, "hr"), + row(0, "it"), + row(2, "it"), + row(2, "it"), + row(4, "it")), + sql("SELECT * FROM %s ORDER BY dep ASC, id ASC", selectTarget())); + } + + private void initTable(String partitionedBy, DeleteGranularity deleteGranularity) { + createAndInitTable("id INT, dep STRING", partitionedBy, null /* empty */); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')", + tableName, TableProperties.DELETE_GRANULARITY, deleteGranularity); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 3, \"dep\": \"hr\" }\n" + "{ \"id\": 4, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 1, \"dep\": \"it\" }\n" + "{ \"id\": 2, \"dep\": \"it\" }"); + append(tableName, "{ \"id\": 3, \"dep\": \"it\" }\n" + "{ \"id\": 4, \"dep\": \"it\" }"); + + createBranchIfNeeded(); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetaColumnProjectionWithStageScan.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetaColumnProjectionWithStageScan.java new file mode 100644 index 000000000000..b783a006ef73 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetaColumnProjectionWithStageScan.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.iceberg.spark.extensions;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.List;
+import java.util.UUID;
+import org.apache.iceberg.ParameterizedTestExtension;
+import org.apache.iceberg.Parameters;
+import org.apache.iceberg.ScanTask;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.ScanTaskSetManager;
+import org.apache.iceberg.spark.Spark3Util;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.iceberg.spark.SparkReadOptions;
+import org.apache.iceberg.spark.source.SimpleRecord;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.TestTemplate;
+import org.junit.jupiter.api.extension.ExtendWith;
+
+@ExtendWith(ParameterizedTestExtension.class)
+public class TestMetaColumnProjectionWithStageScan extends ExtensionsTestBase {
+
+  @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+  public static Object[][] parameters() {
+    return new Object[][] {
+      {
+        SparkCatalogConfig.HADOOP.catalogName(),
+        SparkCatalogConfig.HADOOP.implementation(),
+        SparkCatalogConfig.HADOOP.properties()
+      }
+    };
+  }
+
+  @AfterEach
+  public void removeTables() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  private void stageTask(
+      Table tab, String fileSetID, CloseableIterable<ScanTask> tasks) {
+    ScanTaskSetManager taskSetManager = ScanTaskSetManager.get();
+    taskSetManager.stageTasks(tab, fileSetID, Lists.newArrayList(tasks));
+  }
+
+  @TestTemplate
+  public void testReadStageTableMeta() throws Exception {
+    sql(
+        "CREATE TABLE %s (id bigint, data string) USING iceberg TBLPROPERTIES"
+            + "('format-version'='2', 'write.delete.mode'='merge-on-read')",
+        tableName);
+
+    List<SimpleRecord> records =
+        Lists.newArrayList(
+            new SimpleRecord(1, "a"),
+            new SimpleRecord(2, "b"),
+            new SimpleRecord(3, "c"),
+            new SimpleRecord(4, "d"));
+
+    spark
+        .createDataset(records, Encoders.bean(SimpleRecord.class))
+        .coalesce(1)
+        .writeTo(tableName)
+        .append();
+
+    Table table = Spark3Util.loadIcebergTable(spark, tableName);
+    table.refresh();
+    String tableLocation = table.location();
+
+    try (CloseableIterable<ScanTask> tasks = table.newBatchScan().planFiles()) {
+      String fileSetID = UUID.randomUUID().toString();
+      stageTask(table, fileSetID, tasks);
+      Dataset<Row> scanDF2 =
+          spark
+              .read()
+              .format("iceberg")
+              .option(SparkReadOptions.FILE_OPEN_COST, "0")
+              .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID)
+              .load(tableLocation);
+
+      assertThat(scanDF2.columns()).hasSize(2);
+    }
+
+    try (CloseableIterable<ScanTask> tasks = table.newBatchScan().planFiles()) {
+      String fileSetID = UUID.randomUUID().toString();
+      stageTask(table, fileSetID, tasks);
+      Dataset<Row> scanDF =
+          spark
+              .read()
+              .format("iceberg")
+              .option(SparkReadOptions.FILE_OPEN_COST, "0")
+              .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID)
+              .load(tableLocation)
+              .select("*", "_pos");
+
+      List<Row> rows = scanDF.collectAsList();
+      ImmutableList<Object[]> expectedRows =
+          ImmutableList.of(row(1L, "a", 0L), row(2L, "b", 1L), row(3L, "c", 2L), row(4L, "d", 3L));
+      assertEquals("result should match", expectedRows, rowsToJava(rows));
+    }
+  }
+}
diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetadataTables.java
b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetadataTables.java new file mode 100644 index 000000000000..a22cf61ec8c9 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetadataTables.java @@ -0,0 +1,850 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.avro.generic.GenericData.Record; +import org.apache.commons.collections.ListUtils; +import org.apache.iceberg.FileContent; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.HistoryEntry; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestMetadataTables extends ExtensionsTestBase { + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testUnpartitionedTable() throws Exception { + sql( + "CREATE TABLE %s (id bigint, data string) USING iceberg TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + 
sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + List expectedDataManifests = TestHelpers.dataManifests(table); + List expectedDeleteManifests = TestHelpers.deleteManifests(table); + assertThat(expectedDataManifests).as("Should have 1 data manifest").hasSize(1); + assertThat(expectedDeleteManifests).as("Should have 1 delete manifest").hasSize(1); + + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + Schema filesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".files").schema(); + + // check delete files table + Dataset actualDeleteFilesDs = spark.sql("SELECT * FROM " + tableName + ".delete_files"); + List actualDeleteFiles = TestHelpers.selectNonDerived(actualDeleteFilesDs).collectAsList(); + assertThat(actualDeleteFiles).as("Metadata table should return one delete file").hasSize(1); + + List expectedDeleteFiles = + expectedEntries( + table, FileContent.POSITION_DELETES, entriesTableSchema, expectedDeleteManifests, null); + assertThat(expectedDeleteFiles).as("Should be one delete file manifest entry").hasSize(1); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDeleteFilesDs), + expectedDeleteFiles.get(0), + actualDeleteFiles.get(0)); + + // check data files table + Dataset actualDataFilesDs = spark.sql("SELECT * FROM " + tableName + ".data_files"); + List actualDataFiles = TestHelpers.selectNonDerived(actualDataFilesDs).collectAsList(); + assertThat(actualDataFiles).as("Metadata table should return one data file").hasSize(1); + + List expectedDataFiles = + expectedEntries(table, FileContent.DATA, entriesTableSchema, expectedDataManifests, null); + assertThat(expectedDataFiles).as("Should be one data file manifest entry").hasSize(1); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDataFilesDs), + expectedDataFiles.get(0), + actualDataFiles.get(0)); + + // check all files table + Dataset actualFilesDs = + spark.sql("SELECT * FROM " + tableName + ".files ORDER BY content"); + List actualFiles = TestHelpers.selectNonDerived(actualFilesDs).collectAsList(); + + assertThat(actualFiles).as("Metadata table should return two files").hasSize(2); + + List expectedFiles = + Stream.concat(expectedDataFiles.stream(), expectedDeleteFiles.stream()) + .collect(Collectors.toList()); + assertThat(expectedFiles).as("Should have two files manifest entries").hasSize(2); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles.get(0), actualFiles.get(0)); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles.get(1), actualFiles.get(1)); + } + + @TestTemplate + public void testPartitionedTable() throws Exception { + sql( + "CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List recordsA = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "a")); + spark + .createDataset(recordsA, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + List recordsB = + Lists.newArrayList(new SimpleRecord(1, "b"), new SimpleRecord(2, "b")); + spark + .createDataset(recordsB, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + sql("DELETE FROM %s WHERE id=1 AND data='a'", tableName); + sql("DELETE FROM %s WHERE id=1 AND data='b'", tableName); + + Table table = 
Spark3Util.loadIcebergTable(spark, tableName); + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + + List expectedDataManifests = TestHelpers.dataManifests(table); + List expectedDeleteManifests = TestHelpers.deleteManifests(table); + assertThat(expectedDataManifests).as("Should have 2 data manifest").hasSize(2); + assertThat(expectedDeleteManifests).as("Should have 2 delete manifest").hasSize(2); + + Schema filesTableSchema = + Spark3Util.loadIcebergTable(spark, tableName + ".delete_files").schema(); + + // Check delete files table + List expectedDeleteFiles = + expectedEntries( + table, FileContent.POSITION_DELETES, entriesTableSchema, expectedDeleteManifests, "a"); + assertThat(expectedDeleteFiles).as("Should have one delete file manifest entry").hasSize(1); + + Dataset actualDeleteFilesDs = + spark.sql("SELECT * FROM " + tableName + ".delete_files " + "WHERE partition.data='a'"); + List actualDeleteFiles = actualDeleteFilesDs.collectAsList(); + + assertThat(actualDeleteFiles).as("Metadata table should return one delete file").hasSize(1); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDeleteFilesDs), + expectedDeleteFiles.get(0), + actualDeleteFiles.get(0)); + + // Check data files table + List expectedDataFiles = + expectedEntries(table, FileContent.DATA, entriesTableSchema, expectedDataManifests, "a"); + assertThat(expectedDataFiles).as("Should have one data file manifest entry").hasSize(1); + + Dataset actualDataFilesDs = + spark.sql("SELECT * FROM " + tableName + ".data_files " + "WHERE partition.data='a'"); + + List actualDataFiles = TestHelpers.selectNonDerived(actualDataFilesDs).collectAsList(); + assertThat(actualDataFiles).as("Metadata table should return one data file").hasSize(1); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDataFilesDs), + expectedDataFiles.get(0), + actualDataFiles.get(0)); + + List actualPartitionsWithProjection = + spark.sql("SELECT file_count FROM " + tableName + ".partitions ").collectAsList(); + assertThat(actualPartitionsWithProjection) + .as("Metadata table should return two partitions record") + .hasSize(2) + .containsExactly(RowFactory.create(1), RowFactory.create(1)); + + // Check files table + List expectedFiles = + Stream.concat(expectedDataFiles.stream(), expectedDeleteFiles.stream()) + .collect(Collectors.toList()); + assertThat(expectedFiles).as("Should have two file manifest entries").hasSize(2); + + Dataset actualFilesDs = + spark.sql( + "SELECT * FROM " + tableName + ".files " + "WHERE partition.data='a' ORDER BY content"); + List actualFiles = TestHelpers.selectNonDerived(actualFilesDs).collectAsList(); + assertThat(actualFiles).as("Metadata table should return two files").hasSize(2); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles.get(0), actualFiles.get(0)); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles.get(1), actualFiles.get(1)); + } + + @TestTemplate + public void testAllFilesUnpartitioned() throws Exception { + sql( + "CREATE TABLE %s (id bigint, data string) USING iceberg TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + // Create 
delete file + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + List expectedDataManifests = TestHelpers.dataManifests(table); + assertThat(expectedDataManifests).as("Should have 1 data manifest").hasSize(1); + List expectedDeleteManifests = TestHelpers.deleteManifests(table); + assertThat(expectedDeleteManifests).as("Should have 1 delete manifest").hasSize(1); + + // Clear table to test whether 'all_files' can read past files + List results = sql("DELETE FROM %s", tableName); + assertThat(results).as("Table should be cleared").isEmpty(); + + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + Schema filesTableSchema = + Spark3Util.loadIcebergTable(spark, tableName + ".all_data_files").schema(); + + // Check all data files table + Dataset actualDataFilesDs = spark.sql("SELECT * FROM " + tableName + ".all_data_files"); + List actualDataFiles = TestHelpers.selectNonDerived(actualDataFilesDs).collectAsList(); + + List expectedDataFiles = + expectedEntries(table, FileContent.DATA, entriesTableSchema, expectedDataManifests, null); + assertThat(expectedDataFiles).as("Should be one data file manifest entry").hasSize(1); + assertThat(actualDataFiles).as("Metadata table should return one data file").hasSize(1); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDataFilesDs), + expectedDataFiles.get(0), + actualDataFiles.get(0)); + + // Check all delete files table + Dataset actualDeleteFilesDs = + spark.sql("SELECT * FROM " + tableName + ".all_delete_files"); + List actualDeleteFiles = TestHelpers.selectNonDerived(actualDeleteFilesDs).collectAsList(); + List expectedDeleteFiles = + expectedEntries( + table, FileContent.POSITION_DELETES, entriesTableSchema, expectedDeleteManifests, null); + assertThat(expectedDeleteFiles).as("Should be one delete file manifest entry").hasSize(1); + assertThat(actualDeleteFiles).as("Metadata table should return one delete file").hasSize(1); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDeleteFilesDs), + expectedDeleteFiles.get(0), + actualDeleteFiles.get(0)); + + // Check all files table + Dataset actualFilesDs = + spark.sql("SELECT * FROM " + tableName + ".all_files ORDER BY content"); + List actualFiles = actualFilesDs.collectAsList(); + List expectedFiles = ListUtils.union(expectedDataFiles, expectedDeleteFiles); + expectedFiles.sort(Comparator.comparing(r -> ((Integer) r.get("content")))); + assertThat(actualFiles).as("Metadata table should return two files").hasSize(2); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles, actualFiles); + } + + @TestTemplate + public void testAllFilesPartitioned() throws Exception { + // Create table and insert data + sql( + "CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List recordsA = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "a")); + spark + .createDataset(recordsA, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + List recordsB = + Lists.newArrayList(new SimpleRecord(1, "b"), new SimpleRecord(2, "b")); + spark + .createDataset(recordsB, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + // Create delete file + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = 
Spark3Util.loadIcebergTable(spark, tableName); + List expectedDataManifests = TestHelpers.dataManifests(table); + assertThat(expectedDataManifests).as("Should have 2 data manifests").hasSize(2); + List expectedDeleteManifests = TestHelpers.deleteManifests(table); + assertThat(expectedDeleteManifests).as("Should have 1 delete manifest").hasSize(1); + + // Clear table to test whether 'all_files' can read past files + List results = sql("DELETE FROM %s", tableName); + assertThat(results).as("Table should be cleared").isEmpty(); + + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + Schema filesTableSchema = + Spark3Util.loadIcebergTable(spark, tableName + ".all_data_files").schema(); + + // Check all data files table + Dataset actualDataFilesDs = + spark.sql("SELECT * FROM " + tableName + ".all_data_files " + "WHERE partition.data='a'"); + List actualDataFiles = TestHelpers.selectNonDerived(actualDataFilesDs).collectAsList(); + List expectedDataFiles = + expectedEntries(table, FileContent.DATA, entriesTableSchema, expectedDataManifests, "a"); + assertThat(expectedDataFiles).as("Should be one data file manifest entry").hasSize(1); + assertThat(actualDataFiles).as("Metadata table should return one data file").hasSize(1); + TestHelpers.assertEqualsSafe( + SparkSchemaUtil.convert(TestHelpers.selectNonDerived(actualDataFilesDs).schema()) + .asStruct(), + expectedDataFiles.get(0), + actualDataFiles.get(0)); + + // Check all delete files table + Dataset actualDeleteFilesDs = + spark.sql("SELECT * FROM " + tableName + ".all_delete_files " + "WHERE partition.data='a'"); + List actualDeleteFiles = TestHelpers.selectNonDerived(actualDeleteFilesDs).collectAsList(); + + List expectedDeleteFiles = + expectedEntries( + table, FileContent.POSITION_DELETES, entriesTableSchema, expectedDeleteManifests, "a"); + assertThat(expectedDeleteFiles).as("Should be one data file manifest entry").hasSize(1); + assertThat(actualDeleteFiles).as("Metadata table should return one data file").hasSize(1); + + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDeleteFilesDs), + expectedDeleteFiles.get(0), + actualDeleteFiles.get(0)); + + // Check all files table + Dataset actualFilesDs = + spark.sql( + "SELECT * FROM " + + tableName + + ".all_files WHERE partition.data='a' " + + "ORDER BY content"); + List actualFiles = TestHelpers.selectNonDerived(actualFilesDs).collectAsList(); + + List expectedFiles = ListUtils.union(expectedDataFiles, expectedDeleteFiles); + expectedFiles.sort(Comparator.comparing(r -> ((Integer) r.get("content")))); + assertThat(actualFiles).as("Metadata table should return two files").hasSize(2); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDataFilesDs), expectedFiles, actualFiles); + } + + @TestTemplate + public void testMetadataLogEntries() throws Exception { + // Create table and insert data + sql( + "CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES " + + "('format-version'='2')", + tableName); + + List recordsA = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "a")); + spark.createDataset(recordsA, Encoders.bean(SimpleRecord.class)).writeTo(tableName).append(); + + List recordsB = + Lists.newArrayList(new SimpleRecord(1, "b"), new SimpleRecord(2, "b")); + spark.createDataset(recordsB, Encoders.bean(SimpleRecord.class)).writeTo(tableName).append(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + Long 
currentSnapshotId = table.currentSnapshot().snapshotId(); + TableMetadata tableMetadata = ((HasTableOperations) table).operations().current(); + Snapshot currentSnapshot = tableMetadata.currentSnapshot(); + Snapshot parentSnapshot = table.snapshot(currentSnapshot.parentId()); + List metadataLogEntries = + Lists.newArrayList(tableMetadata.previousFiles()); + + // Check metadataLog table + List metadataLogs = sql("SELECT * FROM %s.metadata_log_entries", tableName); + assertEquals( + "MetadataLogEntriesTable result should match the metadataLog entries", + ImmutableList.of( + row( + DateTimeUtils.toJavaTimestamp(metadataLogEntries.get(0).timestampMillis() * 1000), + metadataLogEntries.get(0).file(), + null, + null, + null), + row( + DateTimeUtils.toJavaTimestamp(metadataLogEntries.get(1).timestampMillis() * 1000), + metadataLogEntries.get(1).file(), + parentSnapshot.snapshotId(), + parentSnapshot.schemaId(), + parentSnapshot.sequenceNumber()), + row( + DateTimeUtils.toJavaTimestamp(currentSnapshot.timestampMillis() * 1000), + tableMetadata.metadataFileLocation(), + currentSnapshot.snapshotId(), + currentSnapshot.schemaId(), + currentSnapshot.sequenceNumber())), + metadataLogs); + + // test filtering + List metadataLogWithFilters = + sql( + "SELECT * FROM %s.metadata_log_entries WHERE latest_snapshot_id = %s", + tableName, currentSnapshotId); + assertThat(metadataLogWithFilters) + .as("metadataLogEntries table should return 1 row") + .hasSize(1); + assertEquals( + "Result should match the latest snapshot entry", + ImmutableList.of( + row( + DateTimeUtils.toJavaTimestamp( + tableMetadata.currentSnapshot().timestampMillis() * 1000), + tableMetadata.metadataFileLocation(), + tableMetadata.currentSnapshot().snapshotId(), + tableMetadata.currentSnapshot().schemaId(), + tableMetadata.currentSnapshot().sequenceNumber())), + metadataLogWithFilters); + + // test projection + List metadataFiles = + metadataLogEntries.stream() + .map(TableMetadata.MetadataLogEntry::file) + .collect(Collectors.toList()); + metadataFiles.add(tableMetadata.metadataFileLocation()); + List metadataLogWithProjection = + sql("SELECT file FROM %s.metadata_log_entries", tableName); + assertThat(metadataLogWithProjection) + .as("metadataLogEntries table should return 3 rows") + .hasSize(3); + assertEquals( + "metadataLog entry should be of same file", + metadataFiles.stream().map(this::row).collect(Collectors.toList()), + metadataLogWithProjection); + } + + @TestTemplate + public void testFilesTableTimeTravelWithSchemaEvolution() throws Exception { + // Create table and insert data + sql( + "CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List recordsA = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "a")); + spark + .createDataset(recordsA, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + table.updateSchema().addColumn("category", Types.StringType.get()).commit(); + + List newRecords = + Lists.newArrayList(RowFactory.create(3, "b", "c"), RowFactory.create(4, "b", "c")); + + StructType newSparkSchema = + SparkSchemaUtil.convert( + new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()), + optional(3, "category", Types.StringType.get()))); + + spark.createDataFrame(newRecords, 
newSparkSchema).coalesce(1).writeTo(tableName).append(); + + Long currentSnapshotId = table.currentSnapshot().snapshotId(); + + Dataset actualFilesDs = + spark.sql( + "SELECT * FROM " + + tableName + + ".files VERSION AS OF " + + currentSnapshotId + + " ORDER BY content"); + List actualFiles = TestHelpers.selectNonDerived(actualFilesDs).collectAsList(); + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + List expectedDataManifests = TestHelpers.dataManifests(table); + List expectedFiles = + expectedEntries(table, FileContent.DATA, entriesTableSchema, expectedDataManifests, null); + + assertThat(actualFiles).as("actualFiles size should be 2").hasSize(2); + + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles.get(0), actualFiles.get(0)); + + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles.get(1), actualFiles.get(1)); + + assertThat(actualFiles) + .as("expectedFiles and actualFiles size should be the same") + .hasSameSizeAs(expectedFiles); + } + + @TestTemplate + public void testSnapshotReferencesMetatable() throws Exception { + // Create table and insert data + sql( + "CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List recordsA = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "a")); + spark + .createDataset(recordsA, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + List recordsB = + Lists.newArrayList(new SimpleRecord(1, "b"), new SimpleRecord(2, "b")); + spark + .createDataset(recordsB, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + Long currentSnapshotId = table.currentSnapshot().snapshotId(); + + // Create branch + table + .manageSnapshots() + .createBranch("testBranch", currentSnapshotId) + .setMaxRefAgeMs("testBranch", 10) + .setMinSnapshotsToKeep("testBranch", 20) + .setMaxSnapshotAgeMs("testBranch", 30) + .commit(); + // Create Tag + table + .manageSnapshots() + .createTag("testTag", currentSnapshotId) + .setMaxRefAgeMs("testTag", 50) + .commit(); + // Check refs table + List references = spark.sql("SELECT * FROM " + tableName + ".refs").collectAsList(); + assertThat(references).as("Refs table should return 3 rows").hasSize(3); + List branches = + spark.sql("SELECT * FROM " + tableName + ".refs WHERE type='BRANCH'").collectAsList(); + assertThat(branches).as("Refs table should return 2 branches").hasSize(2); + List tags = + spark.sql("SELECT * FROM " + tableName + ".refs WHERE type='TAG'").collectAsList(); + assertThat(tags).as("Refs table should return 1 tag").hasSize(1); + + // Check branch entries in refs table + List mainBranch = + spark + .sql("SELECT * FROM " + tableName + ".refs WHERE name = 'main' AND type='BRANCH'") + .collectAsList(); + assertThat(mainBranch) + .hasSize(1) + .containsExactly(RowFactory.create("main", "BRANCH", currentSnapshotId, null, null, null)); + assertThat(mainBranch.get(0).schema().fieldNames()) + .containsExactly( + "name", + "type", + "snapshot_id", + "max_reference_age_in_ms", + "min_snapshots_to_keep", + "max_snapshot_age_in_ms"); + + List testBranch = + spark + .sql("SELECT * FROM " + tableName + ".refs WHERE name = 'testBranch' AND type='BRANCH'") + .collectAsList(); + assertThat(testBranch) + .hasSize(1) 
+ .containsExactly( + RowFactory.create("testBranch", "BRANCH", currentSnapshotId, 10L, 20L, 30L)); + assertThat(testBranch.get(0).schema().fieldNames()) + .containsExactly( + "name", + "type", + "snapshot_id", + "max_reference_age_in_ms", + "min_snapshots_to_keep", + "max_snapshot_age_in_ms"); + + // Check tag entries in refs table + List testTag = + spark + .sql("SELECT * FROM " + tableName + ".refs WHERE name = 'testTag' AND type='TAG'") + .collectAsList(); + assertThat(testTag) + .hasSize(1) + .containsExactly(RowFactory.create("testTag", "TAG", currentSnapshotId, 50L, null, null)); + assertThat(testTag.get(0).schema().fieldNames()) + .containsExactly( + "name", + "type", + "snapshot_id", + "max_reference_age_in_ms", + "min_snapshots_to_keep", + "max_snapshot_age_in_ms"); + + // Check projection in refs table + List testTagProjection = + spark + .sql( + "SELECT name,type,snapshot_id,max_reference_age_in_ms,min_snapshots_to_keep FROM " + + tableName + + ".refs where type='TAG'") + .collectAsList(); + assertThat(testTagProjection) + .hasSize(1) + .containsExactly(RowFactory.create("testTag", "TAG", currentSnapshotId, 50L, null)); + assertThat(testTagProjection.get(0).schema().fieldNames()) + .containsExactly( + "name", "type", "snapshot_id", "max_reference_age_in_ms", "min_snapshots_to_keep"); + + List mainBranchProjection = + spark + .sql( + "SELECT name, type FROM " + + tableName + + ".refs WHERE name = 'main' AND type = 'BRANCH'") + .collectAsList(); + assertThat(mainBranchProjection) + .hasSize(1) + .containsExactly(RowFactory.create("main", "BRANCH")); + assertThat(mainBranchProjection.get(0).schema().fieldNames()).containsExactly("name", "type"); + + List testBranchProjection = + spark + .sql( + "SELECT name, type, snapshot_id, max_reference_age_in_ms FROM " + + tableName + + ".refs WHERE name = 'testBranch' AND type = 'BRANCH'") + .collectAsList(); + assertThat(testBranchProjection) + .hasSize(1) + .containsExactly(RowFactory.create("testBranch", "BRANCH", currentSnapshotId, 10L)); + assertThat(testBranchProjection.get(0).schema().fieldNames()) + .containsExactly("name", "type", "snapshot_id", "max_reference_age_in_ms"); + } + + /** + * Find matching manifest entries of an Iceberg table + * + * @param table iceberg table + * @param expectedContent file content to populate on entries + * @param entriesTableSchema schema of Manifest entries + * @param manifestsToExplore manifests to explore of the table + * @param partValue partition value that manifest entries must match, or null to skip filtering + */ + private List expectedEntries( + Table table, + FileContent expectedContent, + Schema entriesTableSchema, + List manifestsToExplore, + String partValue) + throws IOException { + List expected = Lists.newArrayList(); + for (ManifestFile manifest : manifestsToExplore) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = Avro.read(in).project(entriesTableSchema).build()) { + for (Record record : rows) { + if ((Integer) record.get("status") < 2 /* added or existing */) { + Record file = (Record) record.get("data_file"); + if (partitionMatch(file, partValue)) { + TestHelpers.asMetadataRecord(file, expectedContent); + expected.add(file); + } + } + } + } + } + return expected; + } + + private boolean partitionMatch(Record file, String partValue) { + if (partValue == null) { + return true; + } + Record partition = (Record) file.get(4); + return partValue.equals(partition.get(0).toString()); + } + + @TestTemplate + public void 
metadataLogEntriesAfterReplacingTable() throws Exception { + sql( + "CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES " + + "('format-version'='2')", + tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + TableMetadata tableMetadata = ((HasTableOperations) table).operations().current(); + assertThat(tableMetadata.snapshots()).isEmpty(); + assertThat(tableMetadata.snapshotLog()).isEmpty(); + assertThat(tableMetadata.currentSnapshot()).isNull(); + + Object[] firstEntry = + row( + DateTimeUtils.toJavaTimestamp(tableMetadata.lastUpdatedMillis() * 1000), + tableMetadata.metadataFileLocation(), + null, + null, + null); + + assertThat(sql("SELECT * FROM %s.metadata_log_entries", tableName)).containsExactly(firstEntry); + + sql("INSERT INTO %s (id, data) VALUES (1, 'a')", tableName); + + tableMetadata = ((HasTableOperations) table).operations().refresh(); + assertThat(tableMetadata.snapshots()).hasSize(1); + assertThat(tableMetadata.snapshotLog()).hasSize(1); + Snapshot currentSnapshot = tableMetadata.currentSnapshot(); + + Object[] secondEntry = + row( + DateTimeUtils.toJavaTimestamp(tableMetadata.lastUpdatedMillis() * 1000), + tableMetadata.metadataFileLocation(), + currentSnapshot.snapshotId(), + currentSnapshot.schemaId(), + currentSnapshot.sequenceNumber()); + + assertThat(sql("SELECT * FROM %s.metadata_log_entries", tableName)) + .containsExactly(firstEntry, secondEntry); + + sql("INSERT INTO %s (id, data) VALUES (1, 'a')", tableName); + + tableMetadata = ((HasTableOperations) table).operations().refresh(); + assertThat(tableMetadata.snapshots()).hasSize(2); + assertThat(tableMetadata.snapshotLog()).hasSize(2); + currentSnapshot = tableMetadata.currentSnapshot(); + + Object[] thirdEntry = + row( + DateTimeUtils.toJavaTimestamp(tableMetadata.lastUpdatedMillis() * 1000), + tableMetadata.metadataFileLocation(), + currentSnapshot.snapshotId(), + currentSnapshot.schemaId(), + currentSnapshot.sequenceNumber()); + + assertThat(sql("SELECT * FROM %s.metadata_log_entries", tableName)) + .containsExactly(firstEntry, secondEntry, thirdEntry); + + sql( + "CREATE OR REPLACE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES " + + "('format-version'='2')", + tableName); + + tableMetadata = ((HasTableOperations) table).operations().refresh(); + assertThat(tableMetadata.snapshots()).hasSize(2); + assertThat(tableMetadata.snapshotLog()).hasSize(2); + + // currentSnapshot is null but the metadata_log_entries will refer to the last snapshot from the + // snapshotLog + assertThat(tableMetadata.currentSnapshot()).isNull(); + HistoryEntry historyEntry = tableMetadata.snapshotLog().get(1); + Snapshot lastSnapshot = tableMetadata.snapshot(historyEntry.snapshotId()); + + Object[] fourthEntry = + row( + DateTimeUtils.toJavaTimestamp(tableMetadata.lastUpdatedMillis() * 1000), + tableMetadata.metadataFileLocation(), + lastSnapshot.snapshotId(), + lastSnapshot.schemaId(), + lastSnapshot.sequenceNumber()); + + assertThat(sql("SELECT * FROM %s.metadata_log_entries", tableName)) + .containsExactly(firstEntry, secondEntry, thirdEntry, fourthEntry); + + sql("INSERT INTO %s (id, data) VALUES (1, 'a')", tableName); + + tableMetadata = ((HasTableOperations) table).operations().refresh(); + assertThat(tableMetadata.snapshots()).hasSize(3); + assertThat(tableMetadata.snapshotLog()).hasSize(3); + currentSnapshot = tableMetadata.currentSnapshot(); + + Object[] fifthEntry = + row( + 
DateTimeUtils.toJavaTimestamp(tableMetadata.lastUpdatedMillis() * 1000), + tableMetadata.metadataFileLocation(), + currentSnapshot.snapshotId(), + currentSnapshot.schemaId(), + currentSnapshot.sequenceNumber()); + + assertThat(sql("SELECT * FROM %s.metadata_log_entries", tableName)) + .containsExactly(firstEntry, secondEntry, thirdEntry, fourthEntry, fifthEntry); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java new file mode 100644 index 000000000000..23c08b2572f4 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.IOException; +import java.nio.file.Files; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestMigrateTableProcedure extends ExtensionsTestBase { + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s_BACKUP_", tableName); + } + + @TestTemplate + public void testMigrate() throws IOException { + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + Object result = scalarSql("CALL %s.system.migrate('%s')", catalogName, tableName); + + assertThat(result).as("Should have added one file").isEqualTo(1L); + + Table createdTable = validationCatalog.loadTable(tableIdent); + + String tableLocation = createdTable.location().replace("file:", ""); + assertThat(tableLocation).as("Table should have original location").isEqualTo(location); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, 
"a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DROP TABLE IF EXISTS %s", tableName + "_BACKUP_"); + } + + @TestTemplate + public void testMigrateWithOptions() throws IOException { + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Object result = + scalarSql("CALL %s.system.migrate('%s', map('foo', 'bar'))", catalogName, tableName); + + assertThat(result).as("Should have added one file").isEqualTo(1L); + + Table createdTable = validationCatalog.loadTable(tableIdent); + + Map props = createdTable.properties(); + assertThat(props).containsEntry("foo", "bar"); + + String tableLocation = createdTable.location().replace("file:", ""); + assertThat(tableLocation).as("Table should have original location").isEqualTo(location); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DROP TABLE IF EXISTS %s", tableName + "_BACKUP_"); + } + + @TestTemplate + public void testMigrateWithDropBackup() throws IOException { + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Object result = + scalarSql( + "CALL %s.system.migrate(table => '%s', drop_backup => true)", catalogName, tableName); + assertThat(result).as("Should have added one file").isEqualTo(1L); + assertThat(spark.catalog().tableExists(tableName + "_BACKUP_")).isFalse(); + } + + @TestTemplate + public void testMigrateWithBackupTableName() throws IOException { + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + String backupTableName = "backup_table"; + Object result = + scalarSql( + "CALL %s.system.migrate(table => '%s', backup_table_name => '%s')", + catalogName, tableName, backupTableName); + + assertThat(result).isEqualTo(1L); + String dbName = tableName.split("\\.")[0]; + assertThat(spark.catalog().tableExists(dbName + "." 
+ backupTableName)).isTrue(); + } + + @TestTemplate + public void testMigrateWithInvalidMetricsConfig() throws IOException { + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + + assertThatThrownBy( + () -> { + String props = "map('write.metadata.metrics.column.x', 'X')"; + sql("CALL %s.system.migrate('%s', %s)", catalogName, tableName, props); + }) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith("Invalid metrics config"); + } + + @TestTemplate + public void testMigrateWithConflictingProps() throws IOException { + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Object result = + scalarSql("CALL %s.system.migrate('%s', map('migrated', 'false'))", catalogName, tableName); + assertThat(result).as("Should have added one file").isEqualTo(1L); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.properties()).containsEntry("migrated", "true"); + } + + @TestTemplate + public void testInvalidMigrateCases() { + assertThatThrownBy(() -> sql("CALL %s.system.migrate()", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [table]"); + + assertThatThrownBy(() -> sql("CALL %s.system.migrate(map('foo','bar'))", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith("Wrong arg type for table"); + + assertThatThrownBy(() -> sql("CALL %s.system.migrate('')", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot handle an empty identifier for argument table"); + } + + @TestTemplate + public void testMigratePartitionWithSpecialCharacter() throws IOException { + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string, dt date) USING parquet " + + "PARTITIONED BY (data, dt) LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, '2023/05/30', date '2023-05-30')", tableName); + Object result = scalarSql("CALL %s.system.migrate('%s')", catalogName, tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "2023/05/30", java.sql.Date.valueOf("2023-05-30"))), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testMigrateEmptyPartitionedTable() throws Exception { + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet PARTITIONED BY (id) LOCATION '%s'", + tableName, location); + Object result = scalarSql("CALL %s.system.migrate('%s')", catalogName, tableName); + assertThat(result).isEqualTo(0L); + } + + @TestTemplate + public void testMigrateEmptyTable() throws Exception { + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + String location = 
Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + Object result = scalarSql("CALL %s.system.migrate('%s')", catalogName, tableName); + assertThat(result).isEqualTo(0L); + } + + @TestTemplate + public void testMigrateWithParallelism() throws IOException { + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + List result = + sql("CALL %s.system.migrate(table => '%s', parallelism => %d)", catalogName, tableName, 2); + assertEquals("Procedure output must match", ImmutableList.of(row(2L)), result); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testMigrateWithInvalidParallelism() throws IOException { + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.migrate(table => '%s', parallelism => %d)", + catalogName, tableName, -1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Parallelism should be larger than 0"); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestPublishChangesProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestPublishChangesProcedure.java new file mode 100644 index 000000000000..6284d88a1550 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestPublishChangesProcedure.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import java.util.List; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestPublishChangesProcedure extends ExtensionsTestBase { + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testApplyWapChangesUsingPositionalArgs() { + String wapId = "wap_id_1"; + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", wapId); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + List output = + sql("CALL %s.system.publish_changes('%s', '%s')", catalogName, tableIdent, wapId); + + table.refresh(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(wapSnapshot.snapshotId(), currentSnapshot.snapshotId())), + output); + + assertEquals( + "Apply of WAP changes must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + } + + @TestTemplate + public void testApplyWapChangesUsingNamedArgs() { + String wapId = "wap_id_1"; + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", wapId); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + List output = + sql( + "CALL %s.system.publish_changes(wap_id => '%s', table => '%s')", + catalogName, wapId, tableIdent); + + table.refresh(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(wapSnapshot.snapshotId(), currentSnapshot.snapshotId())), + output); + + assertEquals( + "Apply of WAP changes must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + } + + @TestTemplate + public void testApplyWapChangesRefreshesRelationCache() { + String wapId = "wap_id_1"; + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + Dataset query 
= spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals("View should not produce rows", ImmutableList.of(), sql("SELECT * FROM tmp")); + + spark.conf().set("spark.wap.id", wapId); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + sql("CALL %s.system.publish_changes('%s', '%s')", catalogName, tableIdent, wapId); + + assertEquals( + "Apply of WAP changes should be visible", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM tmp")); + + sql("UNCACHE TABLE tmp"); + } + + @TestTemplate + public void testApplyInvalidWapId() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + assertThatThrownBy( + () -> sql("CALL %s.system.publish_changes('%s', 'not_valid')", catalogName, tableIdent)) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot apply unknown WAP ID 'not_valid'"); + } + + @TestTemplate + public void testInvalidApplyWapChangesCases() { + assertThatThrownBy( + () -> + sql("CALL %s.system.publish_changes('n', table => 't', 'not_valid')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Named and positional arguments cannot be mixed"); + + assertThatThrownBy( + () -> sql("CALL %s.custom.publish_changes('n', 't', 'not_valid')", catalogName)) + .isInstanceOf(ParseException.class) + .satisfies( + exception -> { + ParseException parseException = (ParseException) exception; + assertThat(parseException.getErrorClass()).isEqualTo("PARSE_SYNTAX_ERROR"); + assertThat(parseException.getMessageParameters().get("error")).isEqualTo("'CALL'"); + }); + + assertThatThrownBy(() -> sql("CALL %s.system.publish_changes('t')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [wap_id]"); + + assertThatThrownBy(() -> sql("CALL %s.system.publish_changes('', 'not_valid')", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot handle an empty identifier for argument table"); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRegisterTableProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRegisterTableProcedure.java new file mode 100644 index 000000000000..3047dccd959b --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRegisterTableProcedure.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.iceberg.spark.extensions;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.List;
+import org.apache.iceberg.HasTableOperations;
+import org.apache.iceberg.ParameterizedTestExtension;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.spark.Spark3Util;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.catalyst.parser.ParseException;
+import org.apache.spark.sql.functions;
+import org.apache.spark.sql.types.DataTypes;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+import org.junit.jupiter.api.extension.ExtendWith;
+
+@ExtendWith(ParameterizedTestExtension.class)
+public class TestRegisterTableProcedure extends ExtensionsTestBase {
+
+  private String targetName;
+
+  @BeforeEach
+  public void setTargetName() {
+    targetName = tableName("register_table");
+  }
+
+  @AfterEach
+  public void dropTables() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+    sql("DROP TABLE IF EXISTS %s", targetName);
+  }
+
+  @TestTemplate
+  public void testRegisterTable() throws NoSuchTableException, ParseException {
+    long numRows = 1000;
+
+    sql("CREATE TABLE %s (id int, data string) using ICEBERG", tableName);
+    spark
+        .range(0, numRows)
+        .withColumn("data", functions.col("id").cast(DataTypes.StringType))
+        .writeTo(tableName)
+        .append();
+
+    Table table = Spark3Util.loadIcebergTable(spark, tableName);
+    long originalFileCount = (long) scalarSql("SELECT COUNT(*) from %s.files", tableName);
+    long currentSnapshotId = table.currentSnapshot().snapshotId();
+    String metadataJson =
+        (((HasTableOperations) table).operations()).current().metadataFileLocation();
+
+    List<Object[]> result =
+        sql("CALL %s.system.register_table('%s', '%s')", catalogName, targetName, metadataJson);
+    assertThat(result.get(0)[0]).as("Current Snapshot is not correct").isEqualTo(currentSnapshotId);
+
+    List<Object[]> original = sql("SELECT * FROM %s", tableName);
+    List<Object[]> registered = sql("SELECT * FROM %s", targetName);
+    assertEquals("Registered table rows should match original table rows", original, registered);
+    assertThat(result.get(0)[1])
+        .as("Should have the right row count in the procedure result")
+        .isEqualTo(numRows);
+    assertThat(result.get(0)[2])
+        .as("Should have the right datafile count in the procedure result")
+        .isEqualTo(originalFileCount);
+  }
+}
diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRemoveOrphanFilesProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRemoveOrphanFilesProcedure.java
new file mode 100644
index 000000000000..13836201ba81
--- /dev/null
+++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRemoveOrphanFilesProcedure.java
@@ -0,0 +1,749 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.List; +import java.util.UUID; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.Files; +import org.apache.iceberg.GenericBlobMetadata; +import org.apache.iceberg.GenericStatisticsFile; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PartitionStatisticsFile; +import org.apache.iceberg.ReachableFileUtil; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.Transaction; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.puffin.Blob; +import org.apache.iceberg.puffin.Puffin; +import org.apache.iceberg.puffin.PuffinWriter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.source.FilePathLastModifiedRecord; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestRemoveOrphanFilesProcedure extends ExtensionsTestBase { + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s PURGE", tableName); + sql("DROP TABLE IF EXISTS p PURGE"); + } + + @TestTemplate + public void testRemoveOrphanFilesInEmptyTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + List output = + sql("CALL %s.system.remove_orphan_files('%s')", catalogName, tableIdent); + assertEquals("Should be no orphan files", ImmutableList.of(), output); + + assertEquals("Should have no rows", ImmutableList.of(), sql("SELECT * FROM %s", tableName)); + } + + @TestTemplate + public void testRemoveOrphanFilesInDataFolder() throws IOException { + if (catalogName.equals("testhadoop")) { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + 
} else { + // give a fresh location to Hive tables as Spark will not clean up the table location + // correctly while dropping tables through spark_catalog + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'", + tableName, java.nio.file.Files.createTempDirectory(temp, "junit")); + } + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + String metadataLocation = table.location() + "/metadata"; + String dataLocation = table.location() + "/data"; + + // produce orphan files in the data location using parquet + sql("CREATE TABLE p (id bigint) USING parquet LOCATION '%s'", dataLocation); + sql("INSERT INTO TABLE p VALUES (1)"); + + // wait to ensure files are old enough + waitUntilAfter(System.currentTimeMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + // check for orphans in the metadata folder + List output1 = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s'," + + "location => '%s')", + catalogName, tableIdent, currentTimestamp, metadataLocation); + assertEquals("Should be no orphan files in the metadata folder", ImmutableList.of(), output1); + + // check for orphans in the table location + List output2 = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + assertThat(output2).as("Should be orphan files in the data folder").hasSize(1); + + // the previous call should have deleted all orphan files + List output3 = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + assertThat(output3).as("Should be no more orphan files in the data folder").hasSize(0); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testRemoveOrphanFilesDryRun() throws IOException { + if (catalogName.equals("testhadoop")) { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } else { + // give a fresh location to Hive tables as Spark will not clean up the table location + // correctly while dropping tables through spark_catalog + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'", + tableName, java.nio.file.Files.createTempDirectory(temp, "junit")); + } + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + // produce orphan files in the table location using parquet + sql("CREATE TABLE p (id bigint) USING parquet LOCATION '%s'", table.location()); + sql("INSERT INTO TABLE p VALUES (1)"); + + // wait to ensure files are old enough + waitUntilAfter(System.currentTimeMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + // check for orphans without deleting + List output1 = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s'," + + "dry_run => true)", + catalogName, tableIdent, currentTimestamp); + assertThat(output1).as("Should be one orphan files").hasSize(1); + + // actually delete orphans + List output2 = + sql( + "CALL 
%s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + assertThat(output2).as("Should be one orphan files").hasSize(1); + + // the previous call should have deleted all orphan files + List output3 = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + assertThat(output3).as("Should be no more orphan files").hasSize(0); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testRemoveOrphanFilesGCDisabled() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'false')", tableName, GC_ENABLED); + + assertThatThrownBy( + () -> sql("CALL %s.system.remove_orphan_files('%s')", catalogName, tableIdent)) + .isInstanceOf(ValidationException.class) + .hasMessage( + "Cannot delete orphan files: GC is disabled (deleting files may corrupt other tables)"); + + // reset the property to enable the table purging in removeTable. + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, GC_ENABLED); + } + + @TestTemplate + public void testRemoveOrphanFilesWap() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + List output = + sql("CALL %s.system.remove_orphan_files('%s')", catalogName, tableIdent); + assertEquals("Should be no orphan files", ImmutableList.of(), output); + } + + @TestTemplate + public void testInvalidRemoveOrphanFilesCases() { + assertThatThrownBy( + () -> sql("CALL %s.system.remove_orphan_files('n', table => 't')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Named and positional arguments cannot be mixed"); + + assertThatThrownBy(() -> sql("CALL %s.custom.remove_orphan_files('n', 't')", catalogName)) + .isInstanceOf(NoSuchProcedureException.class) + .hasMessage("Procedure custom.remove_orphan_files not found"); + + assertThatThrownBy(() -> sql("CALL %s.system.remove_orphan_files()", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [table]"); + + assertThatThrownBy(() -> sql("CALL %s.system.remove_orphan_files('n', 2.2)", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith("Wrong arg type for older_than"); + + assertThatThrownBy(() -> sql("CALL %s.system.remove_orphan_files('')", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot handle an empty identifier for argument table"); + } + + @TestTemplate + public void testConcurrentRemoveOrphanFiles() throws IOException { + if (catalogName.equals("testhadoop")) { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } else { + // give a fresh location to Hive tables as Spark will not clean up the table location + // correctly while dropping tables through spark_catalog + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'", + tableName, java.nio.file.Files.createTempDirectory(temp, "junit")); 
+ } + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + String dataLocation = table.location() + "/data"; + + // produce orphan files in the data location using parquet + sql("CREATE TABLE p (id bigint) USING parquet LOCATION '%s'", dataLocation); + sql("INSERT INTO TABLE p VALUES (1)"); + sql("INSERT INTO TABLE p VALUES (10)"); + sql("INSERT INTO TABLE p VALUES (100)"); + sql("INSERT INTO TABLE p VALUES (1000)"); + + // wait to ensure files are old enough + waitUntilAfter(System.currentTimeMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + // check for orphans in the table location + List output = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "max_concurrent_deletes => %s," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, 4, currentTimestamp); + assertThat(output).as("Should be orphan files in the data folder").hasSize(4); + + // the previous call should have deleted all orphan files + List output3 = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "max_concurrent_deletes => %s," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, 4, currentTimestamp); + assertThat(output3).as("Should be no more orphan files in the data folder").hasSize(0); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testConcurrentRemoveOrphanFilesWithInvalidInput() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.remove_orphan_files(table => '%s', max_concurrent_deletes => %s)", + catalogName, tableIdent, 0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("max_concurrent_deletes should have value > 0, value: 0"); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.remove_orphan_files(table => '%s', max_concurrent_deletes => %s)", + catalogName, tableIdent, -1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("max_concurrent_deletes should have value > 0, value: -1"); + + String tempViewName = "file_list_test"; + spark.emptyDataFrame().createOrReplaceTempView(tempViewName); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.remove_orphan_files(table => '%s', file_list_view => '%s')", + catalogName, tableIdent, tempViewName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("No such struct field `file_path` in "); + + spark + .createDataset(Lists.newArrayList(), Encoders.tuple(Encoders.INT(), Encoders.TIMESTAMP())) + .toDF("file_path", "last_modified") + .createOrReplaceTempView(tempViewName); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.remove_orphan_files(table => '%s', file_list_view => '%s')", + catalogName, tableIdent, tempViewName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid file_path column: IntegerType is not a string"); + + spark + .createDataset(Lists.newArrayList(), Encoders.tuple(Encoders.STRING(), Encoders.STRING())) + .toDF("file_path", "last_modified") + .createOrReplaceTempView(tempViewName); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.remove_orphan_files(table => '%s', file_list_view => '%s')", + catalogName, tableIdent, tempViewName)) + .isInstanceOf(IllegalArgumentException.class) + 
.hasMessage("Invalid last_modified column: StringType is not a timestamp"); + } + + @TestTemplate + public void testRemoveOrphanFilesWithDeleteFiles() throws Exception { + sql( + "CREATE TABLE %s (id int, data string) USING iceberg TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + assertThat(TestHelpers.deleteManifests(table)).as("Should have 1 delete manifest").hasSize(1); + assertThat(TestHelpers.deleteFiles(table)).as("Should have 1 delete file").hasSize(1); + Path deleteManifestPath = new Path(TestHelpers.deleteManifests(table).iterator().next().path()); + Path deleteFilePath = + new Path(String.valueOf(TestHelpers.deleteFiles(table).iterator().next().path())); + + // wait to ensure files are old enough + waitUntilAfter(System.currentTimeMillis()); + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + // delete orphans + List output = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + assertThat(output).as("Should be no orphan files").hasSize(0); + + FileSystem localFs = FileSystem.getLocal(new Configuration()); + assertThat(localFs.exists(deleteManifestPath)) + .as("Delete manifest should still exist") + .isTrue(); + assertThat(localFs.exists(deleteFilePath)).as("Delete file should still exist").isTrue(); + + records.remove(new SimpleRecord(1, "a")); + Dataset resultDF = spark.read().format("iceberg").load(tableName); + List actualRecords = + resultDF.as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(actualRecords).as("Rows must match").isEqualTo(records); + } + + @TestTemplate + public void testRemoveOrphanFilesWithStatisticFiles() throws Exception { + sql( + "CREATE TABLE %s USING iceberg " + + "TBLPROPERTIES('format-version'='2') " + + "AS SELECT 10 int, 'abc' data", + tableName); + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + String statsFileName = "stats-file-" + UUID.randomUUID(); + File statsLocation = + new File(new URI(table.location())) + .toPath() + .resolve("data") + .resolve(statsFileName) + .toFile(); + StatisticsFile statisticsFile; + try (PuffinWriter puffinWriter = Puffin.write(Files.localOutput(statsLocation)).build()) { + long snapshotId = table.currentSnapshot().snapshotId(); + long snapshotSequenceNumber = table.currentSnapshot().sequenceNumber(); + puffinWriter.add( + new Blob( + "some-blob-type", + ImmutableList.of(1), + snapshotId, + snapshotSequenceNumber, + ByteBuffer.wrap("blob content".getBytes(StandardCharsets.UTF_8)))); + puffinWriter.finish(); + statisticsFile = + new GenericStatisticsFile( + snapshotId, + statsLocation.toString(), + puffinWriter.fileSize(), + puffinWriter.footerSize(), + puffinWriter.writtenBlobsMetadata().stream() + .map(GenericBlobMetadata::from) + .collect(ImmutableList.toImmutableList())); + } + + Transaction transaction = table.newTransaction(); + transaction + .updateStatistics() + .setStatistics(statisticsFile.snapshotId(), statisticsFile) + .commit(); + transaction.commitTransaction(); + + // wait to ensure files are old 
enough + waitUntilAfter(System.currentTimeMillis()); + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + List output = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + assertThat(output).as("Should be no orphan files").isEmpty(); + + assertThat(statsLocation).exists(); + assertThat(statsLocation).hasSize(statisticsFile.fileSizeInBytes()); + + transaction = table.newTransaction(); + transaction.updateStatistics().removeStatistics(statisticsFile.snapshotId()).commit(); + transaction.commitTransaction(); + + output = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + assertThat(output).as("Should be orphan files").hasSize(1); + assertThat(Iterables.getOnlyElement(output)) + .as("Deleted files") + .containsExactly(statsLocation.toURI().toString()); + assertThat(statsLocation).doesNotExist(); + } + + @TestTemplate + public void testRemoveOrphanFilesWithPartitionStatisticFiles() throws Exception { + sql( + "CREATE TABLE %s USING iceberg " + + "TBLPROPERTIES('format-version'='2') " + + "AS SELECT 10 int, 'abc' data", + tableName); + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + String partitionStatsLocation = ProcedureUtil.statsFileLocation(table.location()); + PartitionStatisticsFile partitionStatisticsFile = + ProcedureUtil.writePartitionStatsFile( + table.currentSnapshot().snapshotId(), partitionStatsLocation, table.io()); + + commitPartitionStatsTxn(table, partitionStatisticsFile); + + // wait to ensure files are old enough + waitUntilAfter(System.currentTimeMillis()); + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + List output = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + assertThat(output).as("Should be no orphan files").isEmpty(); + + assertThat(new File(partitionStatsLocation)).as("partition stats file should exist").exists(); + + removePartitionStatsTxn(table, partitionStatisticsFile); + + output = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + assertThat(output).as("Should be orphan files").hasSize(1); + assertThat(Iterables.getOnlyElement(output)) + .as("Deleted files") + .containsExactly("file:" + partitionStatsLocation); + assertThat(new File(partitionStatsLocation)) + .as("partition stats file should be deleted") + .doesNotExist(); + } + + private static void removePartitionStatsTxn( + Table table, PartitionStatisticsFile partitionStatisticsFile) { + Transaction transaction = table.newTransaction(); + transaction + .updatePartitionStatistics() + .removePartitionStatistics(partitionStatisticsFile.snapshotId()) + .commit(); + transaction.commitTransaction(); + } + + private static void commitPartitionStatsTxn( + Table table, PartitionStatisticsFile partitionStatisticsFile) { + Transaction transaction = table.newTransaction(); + transaction + .updatePartitionStatistics() + .setPartitionStatistics(partitionStatisticsFile) + .commit(); + transaction.commitTransaction(); + } + + @TestTemplate + public void testRemoveOrphanFilesProcedureWithPrefixMode() + throws NoSuchTableException, ParseException, IOException { + if (catalogName.equals("testhadoop")) { + 
sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } else { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'", + tableName, java.nio.file.Files.createTempDirectory(temp, "junit")); + } + Table table = Spark3Util.loadIcebergTable(spark, tableName); + String location = table.location(); + Path originalPath = new Path(location); + + URI uri = originalPath.toUri(); + Path newParentPath = new Path("file1", uri.getAuthority(), uri.getPath()); + + DataFile dataFile1 = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withPath(new Path(newParentPath, "path/to/data-a.parquet").toString()) + .withFileSizeInBytes(10) + .withRecordCount(1) + .build(); + DataFile dataFile2 = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withPath(new Path(newParentPath, "path/to/data-b.parquet").toString()) + .withFileSizeInBytes(10) + .withRecordCount(1) + .build(); + + table.newFastAppend().appendFile(dataFile1).appendFile(dataFile2).commit(); + + Timestamp lastModifiedTimestamp = new Timestamp(10000); + + List allFiles = + Lists.newArrayList( + new FilePathLastModifiedRecord( + new Path(originalPath, "path/to/data-a.parquet").toString(), lastModifiedTimestamp), + new FilePathLastModifiedRecord( + new Path(originalPath, "path/to/data-b.parquet").toString(), lastModifiedTimestamp), + new FilePathLastModifiedRecord( + ReachableFileUtil.versionHintLocation(table), lastModifiedTimestamp)); + + for (String file : ReachableFileUtil.metadataFileLocations(table, true)) { + allFiles.add(new FilePathLastModifiedRecord(file, lastModifiedTimestamp)); + } + + for (ManifestFile manifest : TestHelpers.dataManifests(table)) { + allFiles.add(new FilePathLastModifiedRecord(manifest.path(), lastModifiedTimestamp)); + } + + Dataset compareToFileList = + spark + .createDataFrame(allFiles, FilePathLastModifiedRecord.class) + .withColumnRenamed("filePath", "file_path") + .withColumnRenamed("lastModified", "last_modified"); + String fileListViewName = "files_view"; + compareToFileList.createOrReplaceTempView(fileListViewName); + List orphanFiles = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "equal_schemes => map('file1', 'file')," + + "file_list_view => '%s')", + catalogName, tableIdent, fileListViewName); + assertThat(orphanFiles).isEmpty(); + + // Test with no equal schemes + assertThatThrownBy( + () -> + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "file_list_view => '%s')", + catalogName, tableIdent, fileListViewName)) + .isInstanceOf(ValidationException.class) + .hasMessageEndingWith("Conflicting authorities/schemes: [(file1, file)]."); + + // Drop table in afterEach has purge and fails due to invalid scheme "file1" used in this test + // Dropping the table here + sql("DROP TABLE %s", tableName); + } + + @TestTemplate + public void testRemoveOrphanFilesProcedureWithEqualAuthorities() + throws NoSuchTableException, ParseException, IOException { + if (catalogName.equals("testhadoop")) { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } else { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'", + tableName, java.nio.file.Files.createTempDirectory(temp, "junit")); + } + Table table = Spark3Util.loadIcebergTable(spark, tableName); + Path originalPath = new Path(table.location()); + + URI uri = originalPath.toUri(); + String originalAuthority = uri.getAuthority() == null ? 
"" : uri.getAuthority(); + Path newParentPath = new Path(uri.getScheme(), "localhost", uri.getPath()); + + DataFile dataFile1 = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withPath(new Path(newParentPath, "path/to/data-a.parquet").toString()) + .withFileSizeInBytes(10) + .withRecordCount(1) + .build(); + DataFile dataFile2 = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withPath(new Path(newParentPath, "path/to/data-b.parquet").toString()) + .withFileSizeInBytes(10) + .withRecordCount(1) + .build(); + + table.newFastAppend().appendFile(dataFile1).appendFile(dataFile2).commit(); + + Timestamp lastModifiedTimestamp = new Timestamp(10000); + + List allFiles = + Lists.newArrayList( + new FilePathLastModifiedRecord( + new Path(originalPath, "path/to/data-a.parquet").toString(), lastModifiedTimestamp), + new FilePathLastModifiedRecord( + new Path(originalPath, "path/to/data-b.parquet").toString(), lastModifiedTimestamp), + new FilePathLastModifiedRecord( + ReachableFileUtil.versionHintLocation(table), lastModifiedTimestamp)); + + for (String file : ReachableFileUtil.metadataFileLocations(table, true)) { + allFiles.add(new FilePathLastModifiedRecord(file, lastModifiedTimestamp)); + } + + for (ManifestFile manifest : TestHelpers.dataManifests(table)) { + allFiles.add(new FilePathLastModifiedRecord(manifest.path(), lastModifiedTimestamp)); + } + + Dataset compareToFileList = + spark + .createDataFrame(allFiles, FilePathLastModifiedRecord.class) + .withColumnRenamed("filePath", "file_path") + .withColumnRenamed("lastModified", "last_modified"); + String fileListViewName = "files_view"; + compareToFileList.createOrReplaceTempView(fileListViewName); + List orphanFiles = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "equal_authorities => map('localhost', '%s')," + + "file_list_view => '%s')", + catalogName, tableIdent, originalAuthority, fileListViewName); + assertThat(orphanFiles).isEmpty(); + + // Test with no equal authorities + assertThatThrownBy( + () -> + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "file_list_view => '%s')", + catalogName, tableIdent, fileListViewName)) + .isInstanceOf(ValidationException.class) + .hasMessageEndingWith("Conflicting authorities/schemes: [(localhost, null)]."); + + // Drop table in afterEach has purge and fails due to invalid authority "localhost" + // Dropping the table here + sql("DROP TABLE %s", tableName); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceBranch.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceBranch.java new file mode 100644 index 000000000000..414d5abc8792 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceBranch.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestReplaceBranch extends ExtensionsTestBase { + + private static final String[] TIME_UNITS = {"DAYS", "HOURS", "MINUTES"}; + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties() + } + }; + } + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testReplaceBranchFailsForTag() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + String tagName = "tag1"; + + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createTag(tagName, first).commit(); + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + assertThatThrownBy( + () -> + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d", + tableName, tagName, second)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Ref tag1 is a tag not a branch"); + } + + @TestTemplate + public void testReplaceBranch() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + long expectedMaxRefAgeMs = 1000; + int expectedMinSnapshotsToKeep = 2; + long expectedMaxSnapshotAgeMs = 1000; + table + .manageSnapshots() + .createBranch(branchName, first) + .setMaxRefAgeMs(branchName, expectedMaxRefAgeMs) + .setMinSnapshotsToKeep(branchName, expectedMinSnapshotsToKeep) + 
.setMaxSnapshotAgeMs(branchName, expectedMaxSnapshotAgeMs) + .commit(); + + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + sql("ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d", tableName, branchName, second); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref).isNotNull(); + assertThat(ref.snapshotId()).isEqualTo(second); + assertThat(ref.minSnapshotsToKeep().intValue()).isEqualTo(expectedMinSnapshotsToKeep); + assertThat(ref.maxSnapshotAgeMs().longValue()).isEqualTo(expectedMaxSnapshotAgeMs); + assertThat(ref.maxRefAgeMs().longValue()).isEqualTo(expectedMaxRefAgeMs); + } + + @TestTemplate + public void testReplaceBranchDoesNotExist() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + Table table = validationCatalog.loadTable(tableIdent); + + assertThatThrownBy( + () -> + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d", + tableName, "someBranch", table.currentSnapshot().snapshotId())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Branch does not exist: someBranch"); + } + + @TestTemplate + public void testReplaceBranchWithRetain() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + table.manageSnapshots().createBranch(branchName, first).commit(); + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + long maxRefAge = 10; + for (String timeUnit : TIME_UNITS) { + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d RETAIN %d %s", + tableName, branchName, second, maxRefAge, timeUnit); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref).isNotNull(); + assertThat(ref.snapshotId()).isEqualTo(second); + assertThat(ref.minSnapshotsToKeep()).isNull(); + assertThat(ref.maxSnapshotAgeMs()).isNull(); + assertThat(ref.maxRefAgeMs().longValue()) + .isEqualTo(TimeUnit.valueOf(timeUnit).toMillis(maxRefAge)); + } + } + + @TestTemplate + public void testReplaceBranchWithSnapshotRetention() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + String branchName = "b1"; + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createBranch(branchName, first).commit(); + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + Integer minSnapshotsToKeep = 2; + long maxSnapshotAge = 2; + Long maxRefAgeMs = table.refs().get(branchName).maxRefAgeMs(); + for (String timeUnit : TIME_UNITS) { + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d WITH SNAPSHOT RETENTION %d SNAPSHOTS %d %s", + tableName, branchName, second, minSnapshotsToKeep, maxSnapshotAge, 
timeUnit); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref).isNotNull(); + assertThat(ref.snapshotId()).isEqualTo(second); + assertThat(ref.minSnapshotsToKeep()).isEqualTo(minSnapshotsToKeep); + assertThat(ref.maxSnapshotAgeMs().longValue()) + .isEqualTo(TimeUnit.valueOf(timeUnit).toMillis(maxSnapshotAge)); + assertThat(ref.maxRefAgeMs()).isEqualTo(maxRefAgeMs); + } + } + + @TestTemplate + public void testReplaceBranchWithRetainAndSnapshotRetention() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + table.manageSnapshots().createBranch(branchName, first).commit(); + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + Integer minSnapshotsToKeep = 2; + long maxSnapshotAge = 2; + long maxRefAge = 10; + for (String timeUnit : TIME_UNITS) { + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d RETAIN %d %s WITH SNAPSHOT RETENTION %d SNAPSHOTS %d %s", + tableName, + branchName, + second, + maxRefAge, + timeUnit, + minSnapshotsToKeep, + maxSnapshotAge, + timeUnit); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref).isNotNull(); + assertThat(ref.snapshotId()).isEqualTo(second); + assertThat(ref.minSnapshotsToKeep()).isEqualTo(minSnapshotsToKeep); + assertThat(ref.maxSnapshotAgeMs().longValue()) + .isEqualTo(TimeUnit.valueOf(timeUnit).toMillis(maxSnapshotAge)); + assertThat(ref.maxRefAgeMs().longValue()) + .isEqualTo(TimeUnit.valueOf(timeUnit).toMillis(maxRefAge)); + } + } + + @TestTemplate + public void testCreateOrReplace() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createBranch(branchName, second).commit(); + + sql( + "ALTER TABLE %s CREATE OR REPLACE BRANCH %s AS OF VERSION %d", + tableName, branchName, first); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + assertThat(ref).isNotNull(); + assertThat(ref.snapshotId()).isEqualTo(first); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java new file mode 100644 index 000000000000..fe1c38482bed --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.math.BigDecimal; +import java.util.List; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestRequiredDistributionAndOrdering extends ExtensionsTestBase { + + @AfterEach + public void dropTestTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testDefaultLocalSortWithBucketTransforms() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should insert a local sort by partition columns by default + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @TestTemplate + public void testPartitionColumnsArePrependedForRangeDistribution() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should automatically prepend partition columns to the ordering + sql("ALTER TABLE %s WRITE ORDERED BY c1, c2", tableName); + + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @TestTemplate + public void 
testSortOrderIncludesPartitionColumns() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should succeed with a correct sort order + sql("ALTER TABLE %s WRITE ORDERED BY bucket(2, c3), c1, c2", tableName); + + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @TestTemplate + public void testHashDistributionOnBucketedColumn() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should automatically prepend partition columns to the local ordering after hash distribution + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY c1, c2", tableName); + + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @TestTemplate + public void testDisabledDistributionAndOrdering() { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should fail if ordering is disabled + assertThatThrownBy( + () -> + inputDF + .writeTo(tableName) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .option(SparkWriteOptions.FANOUT_ENABLED, "false") + .append()) + .cause() + .isInstanceOf(IllegalStateException.class) + .hasMessageStartingWith( + "Incoming records violate the writer assumption that records are clustered by spec " + + "and by partition within each spec. 
Either cluster the incoming records or switch to fanout writers."); + } + + @TestTemplate + public void testDefaultSortOnDecimalBucketedColumn() { + sql( + "CREATE TABLE %s (c1 INT, c2 DECIMAL(20, 2)) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c2))", + tableName); + + sql("INSERT INTO %s VALUES (1, 20.2), (2, 40.2), (3, 60.2)", tableName); + + List expected = + ImmutableList.of( + row(1, new BigDecimal("20.20")), + row(2, new BigDecimal("40.20")), + row(3, new BigDecimal("60.20"))); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @TestTemplate + public void testDefaultSortOnStringBucketedColumn() { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c2))", + tableName); + + sql("INSERT INTO %s VALUES (1, 'A'), (2, 'B')", tableName); + + List expected = ImmutableList.of(row(1, "A"), row(2, "B")); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @TestTemplate + public void testDefaultSortOnBinaryBucketedColumn() { + sql( + "CREATE TABLE %s (c1 INT, c2 Binary) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c2))", + tableName); + + sql("INSERT INTO %s VALUES (1, X'A1B1'), (2, X'A2B2')", tableName); + + byte[] bytes1 = new byte[] {-95, -79}; + byte[] bytes2 = new byte[] {-94, -78}; + List expected = ImmutableList.of(row(1, bytes1), row(2, bytes2)); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @TestTemplate + public void testDefaultSortOnDecimalTruncatedColumn() { + sql( + "CREATE TABLE %s (c1 INT, c2 DECIMAL(20, 2)) " + + "USING iceberg " + + "PARTITIONED BY (truncate(2, c2))", + tableName); + + sql("INSERT INTO %s VALUES (1, 20.2), (2, 40.2)", tableName); + + List expected = + ImmutableList.of(row(1, new BigDecimal("20.20")), row(2, new BigDecimal("40.20"))); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @TestTemplate + public void testDefaultSortOnLongTruncatedColumn() { + sql( + "CREATE TABLE %s (c1 INT, c2 BIGINT) " + + "USING iceberg " + + "PARTITIONED BY (truncate(2, c2))", + tableName); + + sql("INSERT INTO %s VALUES (1, 22222222222222), (2, 444444444444)", tableName); + + List expected = ImmutableList.of(row(1, 22222222222222L), row(2, 444444444444L)); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @TestTemplate + public void testRangeDistributionWithQuotedColumnNames() throws NoSuchTableException { + sql( + "CREATE TABLE %s (`c.1` INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, `c.1`))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = + ds.selectExpr("c1 as `c.1`", "c2", "c3").coalesce(1).sortWithinPartitions("`c.1`"); + + sql("ALTER TABLE %s WRITE ORDERED BY `c.1`, c2", tableName); + + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } +} diff --git 
a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java new file mode 100644 index 000000000000..860100583698 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java @@ -0,0 +1,967 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.EnvironmentContext; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.NamedReference; +import org.apache.iceberg.expressions.Zorder; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.ExtendedParser; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkTableCache; +import org.apache.iceberg.spark.SystemFunctionPushDownHelper; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestRewriteDataFilesProcedure extends ExtensionsTestBase { + + private static final String QUOTED_SPECIAL_CHARS_TABLE_NAME = "`table:with.special:chars`"; + + @BeforeAll + public static void setupSpark() { + // disable AQE as tests assume that writes generate a particular number of files + spark.conf().set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"); + } + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + } + + @TestTemplate + public void testZOrderSortExpression() { + 
List order = + ExtendedParser.parseSortOrder(spark, "c1, zorder(c2, c3)"); + assertThat(order).as("Should parse 2 order fields").hasSize(2); + assertThat(((NamedReference) order.get(0).term()).name()) + .as("First field should be a ref") + .isEqualTo("c1"); + assertThat(order.get(1).term()).as("Second field should be zorder").isInstanceOf(Zorder.class); + } + + @TestTemplate + public void testRewriteDataFilesInEmptyTable() { + createTable(); + List output = sql("CALL %s.system.rewrite_data_files('%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(0, 0, 0L, 0)), output); + } + + @TestTemplate + public void testRewriteDataFilesOnPartitionTable() { + createPartitionTable(); + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + insertData(10); + List expectedRecords = currentData(); + + List output = + sql("CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent); + + assertEquals( + "Action should rewrite 10 data files and add 2 data files (one per partition) ", + row(10, 2), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(4); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @TestTemplate + public void testRewriteDataFilesOnNonPartitionTable() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + List expectedRecords = currentData(); + + List output = + sql("CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent); + + assertEquals( + "Action should rewrite 10 data files and add 1 data files", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(4); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @TestTemplate + public void testRewriteDataFilesWithOptions() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + List expectedRecords = currentData(); + + // set the min-input-files = 12, instead of default 5 to skip compacting the files. 
+ List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','12'))", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 0 data files and add 0 data files", + ImmutableList.of(row(0, 0, 0L, 0)), + output); + + List actualRecords = currentData(); + assertEquals("Data should not change", expectedRecords, actualRecords); + } + + @TestTemplate + public void testRewriteDataFilesWithSortStrategy() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + List expectedRecords = currentData(); + + // set sort_order = c1 DESC LAST + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s', " + + "strategy => 'sort', sort_order => 'c1 DESC NULLS LAST')", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 10 data files and add 1 data files", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(4); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @TestTemplate + public void testRewriteDataFilesWithSortStrategyAndMultipleShufflePartitionsPerFile() { + createTable(); + insertData(10 /* file count */); + + List output = + sql( + "CALL %s.system.rewrite_data_files(" + + " table => '%s', " + + " strategy => 'sort', " + + " sort_order => 'c1', " + + " options => map('shuffle-partitions-per-file', '2'))", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 10 data files and add 1 data files", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + + // as there is only one small output file, validate the query ordering (it will not change) + ImmutableList expectedRows = + ImmutableList.of( + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null), + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null)); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s", tableName)); + } + + @TestTemplate + public void testRewriteDataFilesWithZOrder() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + + // set z_order = c1,c2 + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s', " + + "strategy => 'sort', sort_order => 'zorder(c1,c2)')", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 10 data files and add 1 data files", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(4); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + // Due to Z_order, the data written will be in the below order. + // As there is only one small output file, we can validate the query ordering (as it will not + // change). 
+ ImmutableList expectedRows = + ImmutableList.of( + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null)); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s", tableName)); + } + + @TestTemplate + public void testRewriteDataFilesWithZOrderNullBinaryColumn() { + sql("CREATE TABLE %s (c1 int, c2 string, c3 binary) USING iceberg", tableName); + + for (int i = 0; i < 5; i++) { + sql("INSERT INTO %s values (1, 'foo', null), (2, 'bar', null)", tableName); + } + + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s', " + + "strategy => 'sort', sort_order => 'zorder(c2,c3)')", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 10 data files and add 1 data files", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + assertThat(output.get(0)).hasSize(4); + assertThat(snapshotSummary()) + .containsEntry(SnapshotSummary.REMOVED_FILE_SIZE_PROP, String.valueOf(output.get(0)[2])); + assertThat(sql("SELECT * FROM %s", tableName)) + .containsExactly( + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null)); + } + + @TestTemplate + public void testRewriteDataFilesWithZOrderAndMultipleShufflePartitionsPerFile() { + createTable(); + insertData(10 /* file count */); + + List output = + sql( + "CALL %s.system.rewrite_data_files(" + + " table => '%s', " + + "strategy => 'sort', " + + " sort_order => 'zorder(c1, c2)', " + + " options => map('shuffle-partitions-per-file', '2'))", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 10 data files and add 1 data files", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + + // due to z-ordering, the data will be written in the below order + // as there is only one small output file, validate the query ordering (it will not change) + ImmutableList expectedRows = + ImmutableList.of( + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null)); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s", tableName)); + } + + @TestTemplate + public void testRewriteDataFilesWithFilter() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + List expectedRecords = currentData(); + + // select only 5 files for compaction (files that may have c1 = 1) + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c1 = 1 and c2 is not null')", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 5 data files (containing c1 = 1) and add 1 data files", + row(5, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(4); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @TestTemplate + public void testRewriteDataFilesWithDeterministicTrueFilter() { + createTable(); + // create 10 files under 
non-partitioned table + insertData(10); + List expectedRecords = currentData(); + // select all 10 files for compaction + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s', where => '1=1')", + catalogName, tableIdent); + assertEquals( + "Action should rewrite 10 data files and add 1 data files", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(4); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @TestTemplate + public void testRewriteDataFilesWithDeterministicFalseFilter() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + List expectedRecords = currentData(); + // select no files for compaction + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s', where => '0=1')", + catalogName, tableIdent); + assertEquals( + "Action should rewrite 0 data files and add 0 data files", + row(0, 0), + Arrays.copyOf(output.get(0), 2)); + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @TestTemplate + public void testRewriteDataFilesWithFilterOnPartitionTable() { + createPartitionTable(); + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + insertData(10); + List expectedRecords = currentData(); + + // select only 5 files for compaction (files in the partition c2 = 'bar') + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c2 = \"bar\"')", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 5 data files from single matching partition" + + "(containing c2 = bar) and add 1 data files", + row(5, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(4); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @TestTemplate + public void testRewriteDataFilesWithFilterOnOnBucketExpression() { + // currently spark session catalog only resolve to v1 functions instead of desired v2 functions + // https://github.com/apache/spark/blob/branch-3.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala#L2070-L2083 + assumeThat(catalogName).isNotEqualTo(SparkCatalogConfig.SPARK.catalogName()); + createBucketPartitionTable(); + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + insertData(10); + List expectedRecords = currentData(); + + // select only 5 files for compaction (files in the partition c2 = 'bar') + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + + " where => '%s.system.bucket(2, c2) = 0')", + catalogName, tableIdent, catalogName); + + assertEquals( + "Action should rewrite 5 data files from single matching partition" + + "(containing bucket(c2) = 0) and add 1 data files", + row(5, 1), + row(output.get(0)[0], output.get(0)[1])); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(4); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + 
.isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @TestTemplate + public void testRewriteDataFilesWithInFilterOnPartitionTable() { + createPartitionTable(); + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + insertData(10); + List expectedRecords = currentData(); + + // select only 5 files for compaction (files in the partition c2 in ('bar')) + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c2 in (\"bar\")')", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 5 data files from single matching partition" + + "(containing c2 = bar) and add 1 data files", + row(5, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(4); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @TestTemplate + public void testRewriteDataFilesWithAllPossibleFilters() { + createPartitionTable(); + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + insertData(10); + + // Pass the literal value which is not present in the data files. + // So that parsing can be tested on a same dataset without actually compacting the files. + + // EqualTo + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 = 3')", + catalogName, tableIdent); + // GreaterThan + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 > 3')", + catalogName, tableIdent); + // GreaterThanOrEqual + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 >= 3')", + catalogName, tableIdent); + // LessThan + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 < 0')", + catalogName, tableIdent); + // LessThanOrEqual + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 <= 0')", + catalogName, tableIdent); + // In + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 in (3,4,5)')", + catalogName, tableIdent); + // IsNull + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 is null')", + catalogName, tableIdent); + // IsNotNull + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c3 is not null')", + catalogName, tableIdent); + // And + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 = 3 and c2 = \"bar\"')", + catalogName, tableIdent); + // Or + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 = 3 or c1 = 5')", + catalogName, tableIdent); + // Not + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 not in (1,2)')", + catalogName, tableIdent); + // StringStartsWith + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c2 like \"%s\"')", + catalogName, tableIdent, "car%"); + // TODO: Enable when org.apache.iceberg.spark.SparkFilters have implementations for + // StringEndsWith & StringContains + // StringEndsWith + // sql("CALL %s.system.rewrite_data_files(table => '%s'," + + // " where => 'c2 like \"%s\"')", catalogName, tableIdent, "%car"); + // StringContains + // sql("CALL %s.system.rewrite_data_files(table => '%s'," + + // " where 
=> 'c2 like \"%s\"')", catalogName, tableIdent, "%car%"); + } + + @TestTemplate + public void testRewriteDataFilesWithPossibleV2Filters() { + // currently spark session catalog only resolve to v1 functions instead of desired v2 functions + // https://github.com/apache/spark/blob/branch-3.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala#L2070-L2083 + assumeThat(catalogName).isNotEqualTo(SparkCatalogConfig.SPARK.catalogName()); + + SystemFunctionPushDownHelper.createPartitionedTable(spark, tableName, "id"); + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + + " where => '%s.system.bucket(2, data) >= 0')", + catalogName, tableIdent, catalogName); + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + + " where => '%s.system.truncate(4, id) >= 1')", + catalogName, tableIdent, catalogName); + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + + " where => '%s.system.years(ts) >= 1')", + catalogName, tableIdent, catalogName); + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + + " where => '%s.system.months(ts) >= 1')", + catalogName, tableIdent, catalogName); + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + + " where => '%s.system.days(ts) >= date(\"2023-01-01\")')", + catalogName, tableIdent, catalogName); + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + + " where => '%s.system.hours(ts) >= 1')", + catalogName, tableIdent, catalogName); + } + + @TestTemplate + public void testRewriteDataFilesWithInvalidInputs() { + createTable(); + // create 2 files under non-partitioned table + insertData(2); + + // Test for invalid strategy + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','2'), " + + "strategy => 'temp')", + catalogName, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("unsupported strategy: temp. 
Only binpack or sort is supported"); + + // Test for sort_order with binpack strategy + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'binpack', " + + "sort_order => 'c1 ASC NULLS FIRST')", + catalogName, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Must use only one rewriter type (bin-pack, sort, zorder)"); + + // Test for sort strategy without any (default/user defined) sort_order + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort')", + catalogName, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Cannot sort data without a valid sort order"); + + // Test for sort_order with invalid null order + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " + + "sort_order => 'c1 ASC none')", + catalogName, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Unable to parse sortOrder: c1 ASC none"); + + // Test for sort_order with invalid sort direction + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " + + "sort_order => 'c1 none NULLS FIRST')", + catalogName, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Unable to parse sortOrder: c1 none NULLS FIRST"); + + // Test for sort_order with invalid column name + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " + + "sort_order => 'col1 DESC NULLS FIRST')", + catalogName, tableIdent)) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith("Cannot find field 'col1' in struct:"); + + // Test with invalid filter column col1 + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', " + "where => 'col1 = 3')", + catalogName, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot parse predicates in where option: col1 = 3"); + + // Test for z_order with invalid column name + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " + + "sort_order => 'zorder(col1)')", + catalogName, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Cannot find column 'col1' in table schema (case sensitive = false): " + + "struct<1: c1: optional int, 2: c2: optional string, 3: c3: optional string>"); + + // Test for z_order with sort_order + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " + + "sort_order => 'c1,zorder(c2,c3)')", + catalogName, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Cannot mix identity sort columns and a Zorder sort expression:" + " c1,zorder(c2,c3)"); + } + + @TestTemplate + public void testInvalidCasesForRewriteDataFiles() { + assertThatThrownBy( + () -> sql("CALL %s.system.rewrite_data_files('n', table => 't')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Named and positional arguments cannot be mixed"); + + assertThatThrownBy(() -> sql("CALL %s.custom.rewrite_data_files('n', 't')", catalogName)) + .isInstanceOf(NoSuchProcedureException.class) + .hasMessage("Procedure custom.rewrite_data_files not found"); + + assertThatThrownBy(() -> sql("CALL %s.system.rewrite_data_files()", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Missing required 
parameters: [table]"); + + assertThatThrownBy( + () -> sql("CALL %s.system.rewrite_data_files(table => 't', table => 't')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessageEndingWith( + "Could not build name to arg map: Duplicate procedure argument: table"); + + assertThatThrownBy(() -> sql("CALL %s.system.rewrite_data_files('')", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot handle an empty identifier for parameter 'table'"); + } + + @TestTemplate + public void testBinPackTableWithSpecialChars() { + assumeThat(catalogName).isEqualTo(SparkCatalogConfig.HADOOP.catalogName()); + + TableIdentifier identifier = + TableIdentifier.of("default", QUOTED_SPECIAL_CHARS_TABLE_NAME.replaceAll("`", "")); + sql( + "CREATE TABLE %s (c1 int, c2 string, c3 string) USING iceberg", + tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + insertData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME), 10); + + List expectedRecords = currentData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s', where => 'c2 is not null')", + catalogName, tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + assertEquals( + "Action should rewrite 10 data files and add 1 data file", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(4); + assertThat(output.get(0)[2]) + .isEqualTo( + Long.valueOf(snapshotSummary(identifier).get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + + assertThat(SparkTableCache.get().size()).as("Table cache must be empty").isEqualTo(0); + } + + @TestTemplate + public void testSortTableWithSpecialChars() { + assumeThat(catalogName).isEqualTo(SparkCatalogConfig.HADOOP.catalogName()); + + TableIdentifier identifier = + TableIdentifier.of("default", QUOTED_SPECIAL_CHARS_TABLE_NAME.replaceAll("`", "")); + sql( + "CREATE TABLE %s (c1 int, c2 string, c3 string) USING iceberg", + tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + insertData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME), 10); + + List expectedRecords = currentData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + List output = + sql( + "CALL %s.system.rewrite_data_files(" + + " table => '%s'," + + " strategy => 'sort'," + + " sort_order => 'c1'," + + " where => 'c2 is not null')", + catalogName, tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + assertEquals( + "Action should rewrite 10 data files and add 1 data file", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(4); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo( + Long.valueOf(snapshotSummary(identifier).get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + + assertThat(SparkTableCache.get().size()).as("Table cache must be empty").isEqualTo(0); + } + + @TestTemplate + public void testZOrderTableWithSpecialChars() { + assumeThat(catalogName).isEqualTo(SparkCatalogConfig.HADOOP.catalogName()); + + TableIdentifier identifier = + TableIdentifier.of("default", QUOTED_SPECIAL_CHARS_TABLE_NAME.replaceAll("`", "")); + sql( + "CREATE TABLE %s (c1 int, c2 string, c3 string) USING iceberg", + 
tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + insertData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME), 10); + + List expectedRecords = currentData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + List output = + sql( + "CALL %s.system.rewrite_data_files(" + + " table => '%s'," + + " strategy => 'sort'," + + " sort_order => 'zorder(c1, c2)'," + + " where => 'c2 is not null')", + catalogName, tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + assertEquals( + "Action should rewrite 10 data files and add 1 data file", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(4); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo( + Long.valueOf(snapshotSummary(identifier).get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + + assertThat(SparkTableCache.get().size()).as("Table cache must be empty").isEqualTo(0); + } + + @TestTemplate + public void testDefaultSortOrder() { + createTable(); + // add a default sort order for a table + sql("ALTER TABLE %s WRITE ORDERED BY c2", tableName); + + // this creates 2 files under non-partitioned table due to sort order. + insertData(10); + List expectedRecords = currentData(); + + // When the strategy is set to 'sort' but the sort order is not specified, + // use table's default sort order. + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s', " + + "strategy => 'sort', " + + "options => map('min-input-files','2'))", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 2 data files and add 1 data files", + row(2, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(4); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @TestTemplate + public void testRewriteWithUntranslatedOrUnconvertedFilter() { + createTable(); + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', where => 'substr(encode(c2, \"utf-8\"), 2) = \"fo\"')", + catalogName, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Cannot translate Spark expression"); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', where => 'substr(c2, 2) = \"fo\"')", + catalogName, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Cannot convert Spark filter"); + } + + @TestTemplate + public void testRewriteDataFilesSummary() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + sql("CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent); + + Map summary = snapshotSummary(); + assertThat(summary) + .containsKey(CatalogProperties.APP_ID) + .containsEntry(EnvironmentContext.ENGINE_NAME, "spark") + .hasEntrySatisfying( + EnvironmentContext.ENGINE_VERSION, v -> assertThat(v).startsWith("4.0")); + } + + private void createTable() { + sql("CREATE TABLE %s (c1 int, c2 string, c3 string) USING iceberg", tableName); + } + + private void createPartitionTable() { + sql( + "CREATE TABLE %s (c1 int, c2 string, c3 string) " + + "USING iceberg " + + 
"PARTITIONED BY (c2) " + + "TBLPROPERTIES ('%s' '%s')", + tableName, + TableProperties.WRITE_DISTRIBUTION_MODE, + TableProperties.WRITE_DISTRIBUTION_MODE_NONE); + } + + private void createBucketPartitionTable() { + sql( + "CREATE TABLE %s (c1 int, c2 string, c3 string) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c2)) " + + "TBLPROPERTIES ('%s' '%s')", + tableName, + TableProperties.WRITE_DISTRIBUTION_MODE, + TableProperties.WRITE_DISTRIBUTION_MODE_NONE); + } + + private void insertData(int filesCount) { + insertData(tableName, filesCount); + } + + private void insertData(String table, int filesCount) { + ThreeColumnRecord record1 = new ThreeColumnRecord(1, "foo", null); + ThreeColumnRecord record2 = new ThreeColumnRecord(2, "bar", null); + + List records = Lists.newArrayList(); + IntStream.range(0, filesCount / 2) + .forEach( + i -> { + records.add(record1); + records.add(record2); + }); + + Dataset df = + spark.createDataFrame(records, ThreeColumnRecord.class).repartition(filesCount); + try { + df.writeTo(table).append(); + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) { + throw new RuntimeException(e); + } + } + + private Map snapshotSummary() { + return snapshotSummary(tableIdent); + } + + private Map snapshotSummary(TableIdentifier tableIdentifier) { + return validationCatalog.loadTable(tableIdentifier).currentSnapshot().summary(); + } + + private List currentData() { + return currentData(tableName); + } + + private List currentData(String table) { + return rowsToJava(spark.sql("SELECT * FROM " + table + " order by c1, c2, c3").collectAsList()); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteManifestsProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteManifestsProcedure.java new file mode 100644 index 000000000000..5eebd9aeb711 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteManifestsProcedure.java @@ -0,0 +1,390 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestRewriteManifestsProcedure extends ExtensionsTestBase { + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testRewriteManifestsInEmptyTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + List output = sql("CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(0, 0)), output); + } + + @TestTemplate + public void testRewriteLargeManifests() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", + tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 1 manifest") + .hasSize(1); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('commit.manifest.target-size-bytes' '1')", tableName); + + List output = sql("CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(1, 4)), output); + + table.refresh(); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 4 manifests") + .hasSize(4); + } + + @TestTemplate + public void testRewriteManifestsNoOp() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", + tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 1 manifest") + .hasSize(1); + + List output = sql("CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + // should not rewrite any manifests for no-op (output of rewrite is same as before and after) + assertEquals("Procedure output must match", ImmutableList.of(row(0, 0)), output); + + table.refresh(); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 1 manifest") + .hasSize(1); + } + + @TestTemplate + public void testRewriteLargeManifestsOnDatePartitionedTableWithJava8APIEnabled() { + withSQLConf( + ImmutableMap.of("spark.sql.datetime.java8API.enabled", "true"), + () -> { + sql( + "CREATE TABLE %s (id INTEGER, name STRING, dept STRING, ts DATE) USING iceberg PARTITIONED BY (ts)", + tableName); + try { + spark + .createDataFrame( + ImmutableList.of( + 
RowFactory.create(1, "John Doe", "hr", Date.valueOf("2021-01-01")), + RowFactory.create(2, "Jane Doe", "hr", Date.valueOf("2021-01-02")), + RowFactory.create(3, "Matt Doe", "hr", Date.valueOf("2021-01-03")), + RowFactory.create(4, "Will Doe", "facilities", Date.valueOf("2021-01-04"))), + spark.table(tableName).schema()) + .writeTo(tableName) + .append(); + } catch (NoSuchTableException e) { + // not possible as we already created the table above. + throw new RuntimeException(e); + } + + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 1 manifest") + .hasSize(1); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('commit.manifest.target-size-bytes' '1')", + tableName); + + List output = + sql("CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(1, 4)), output); + + table.refresh(); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 4 manifests") + .hasSize(4); + }); + } + + @TestTemplate + public void testRewriteLargeManifestsOnTimestampPartitionedTableWithJava8APIEnabled() { + withSQLConf( + ImmutableMap.of("spark.sql.datetime.java8API.enabled", "true"), + () -> { + sql( + "CREATE TABLE %s (id INTEGER, name STRING, dept STRING, ts TIMESTAMP) USING iceberg PARTITIONED BY (ts)", + tableName); + try { + spark + .createDataFrame( + ImmutableList.of( + RowFactory.create( + 1, "John Doe", "hr", Timestamp.valueOf("2021-01-01 00:00:00")), + RowFactory.create( + 2, "Jane Doe", "hr", Timestamp.valueOf("2021-01-02 00:00:00")), + RowFactory.create( + 3, "Matt Doe", "hr", Timestamp.valueOf("2021-01-03 00:00:00")), + RowFactory.create( + 4, "Will Doe", "facilities", Timestamp.valueOf("2021-01-04 00:00:00"))), + spark.table(tableName).schema()) + .writeTo(tableName) + .append(); + } catch (NoSuchTableException e) { + // not possible as we already created the table above. 
+ throw new RuntimeException(e); + } + + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 1 manifest") + .hasSize(1); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('commit.manifest.target-size-bytes' '1')", + tableName); + + List output = + sql("CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(1, 4)), output); + + table.refresh(); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 4 manifests") + .hasSize(4); + }); + } + + @TestTemplate + public void testRewriteSmallManifestsWithSnapshotIdInheritance() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", + tableName); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')", + tableName, SNAPSHOT_ID_INHERITANCE_ENABLED, "true"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + sql("INSERT INTO TABLE %s VALUES (3, 'c')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 4 manifests") + .hasSize(4); + + List output = + sql("CALL %s.system.rewrite_manifests(table => '%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(4, 1)), output); + + table.refresh(); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 1 manifest") + .hasSize(1); + } + + @TestTemplate + public void testRewriteSmallManifestsWithoutCaching() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", + tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 2 manifest") + .hasSize(2); + + List output = + sql( + "CALL %s.system.rewrite_manifests(use_caching => false, table => '%s')", + catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(2, 1)), output); + + table.refresh(); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 1 manifest") + .hasSize(1); + } + + @TestTemplate + public void testRewriteManifestsCaseInsensitiveArgs() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", + tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 2 manifests") + .hasSize(2); + + List output = + sql( + "CALL %s.system.rewrite_manifests(usE_cAcHiNg => false, tAbLe => '%s')", + catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(2, 1)), output); + + table.refresh(); + + assertThat(table.currentSnapshot().allManifests(table.io())) + .as("Must have 1 manifest") + .hasSize(1); + } + + @TestTemplate + public void testInvalidRewriteManifestsCases() { + assertThatThrownBy( + () -> sql("CALL %s.system.rewrite_manifests('n', table => 't')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Named and 
positional arguments cannot be mixed"); + + assertThatThrownBy(() -> sql("CALL %s.custom.rewrite_manifests('n', 't')", catalogName)) + .isInstanceOf(ParseException.class) + .satisfies( + exception -> { + ParseException parseException = (ParseException) exception; + assertThat(parseException.getErrorClass()).isEqualTo("PARSE_SYNTAX_ERROR"); + assertThat(parseException.getMessageParameters().get("error")).isEqualTo("'CALL'"); + }); + + assertThatThrownBy(() -> sql("CALL %s.system.rewrite_manifests()", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [table]"); + + assertThatThrownBy(() -> sql("CALL %s.system.rewrite_manifests('n', 2.2)", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith("Wrong arg type for use_caching"); + + assertThatThrownBy( + () -> sql("CALL %s.system.rewrite_manifests(table => 't', tAbLe => 't')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Could not build name to arg map: Duplicate procedure argument: table"); + + assertThatThrownBy(() -> sql("CALL %s.system.rewrite_manifests('')", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot handle an empty identifier for argument table"); + } + + @TestTemplate + public void testReplacePartitionField() { + sql( + "CREATE TABLE %s (id int, ts timestamp, day_of_ts date) USING iceberg PARTITIONED BY (day_of_ts)", + tableName); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('format-version' = '2')", tableName); + sql("ALTER TABLE %s REPLACE PARTITION FIELD day_of_ts WITH days(ts)\n", tableName); + sql( + "INSERT INTO %s VALUES (1, CAST('2022-01-01 10:00:00' AS TIMESTAMP), CAST('2022-01-01' AS DATE))", + tableName); + sql( + "INSERT INTO %s VALUES (2, CAST('2022-01-01 11:00:00' AS TIMESTAMP), CAST('2022-01-01' AS DATE))", + tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(1, Timestamp.valueOf("2022-01-01 10:00:00"), Date.valueOf("2022-01-01")), + row(2, Timestamp.valueOf("2022-01-01 11:00:00"), Date.valueOf("2022-01-01"))), + sql("SELECT * FROM %s WHERE ts < current_timestamp() order by 1 asc", tableName)); + + List output = + sql("CALL %s.system.rewrite_manifests(table => '%s')", catalogName, tableName); + assertEquals("Procedure output must match", ImmutableList.of(row(2, 1)), output); + + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(1, Timestamp.valueOf("2022-01-01 10:00:00"), Date.valueOf("2022-01-01")), + row(2, Timestamp.valueOf("2022-01-01 11:00:00"), Date.valueOf("2022-01-01"))), + sql("SELECT * FROM %s WHERE ts < current_timestamp() order by 1 asc", tableName)); + } + + @TestTemplate + public void testWriteManifestWithSpecId() { + sql( + "CREATE TABLE %s (id int, dt string, hr string) USING iceberg PARTITIONED BY (dt)", + tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('commit.manifest-merge.enabled' = 'false')", tableName); + + sql("INSERT INTO %s VALUES (1, '2024-01-01', '00')", tableName); + sql("INSERT INTO %s VALUES (2, '2024-01-01', '00')", tableName); + assertEquals( + "Should have 2 manifests and their partition spec id should be 0", + ImmutableList.of(row(0), row(0)), + sql("SELECT partition_spec_id FROM %s.manifests order by 1 asc", tableName)); + + sql("ALTER TABLE %s ADD PARTITION FIELD hr", tableName); + sql("INSERT INTO %s VALUES (3, '2024-01-01', '00')", tableName); + assertEquals( + "Should have 3 manifests and their partition spec id should be 0 and 1", + ImmutableList.of(row(0), row(0), row(1)), + 
sql("SELECT partition_spec_id FROM %s.manifests order by 1 asc", tableName)); + + List output = sql("CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + assertEquals("Nothing should be rewritten", ImmutableList.of(row(0, 0)), output); + + output = + sql( + "CALL %s.system.rewrite_manifests(table => '%s', spec_id => 0)", + catalogName, tableIdent); + assertEquals("There should be 2 manifests rewriten", ImmutableList.of(row(2, 1)), output); + assertEquals( + "Should have 2 manifests and their partition spec id should be 0 and 1", + ImmutableList.of(row(0), row(1)), + sql("SELECT partition_spec_id FROM %s.manifests order by 1 asc", tableName)); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFiles.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFiles.java new file mode 100644 index 000000000000..f7329e841800 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFiles.java @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT_DEFAULT; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.Closeable; +import java.io.IOException; +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Timestamp; +import java.time.LocalDateTime; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.MetadataTableUtils; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PositionDeletesScanTask; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.RewritePositionDeleteFiles.FileGroupRewriteResult; +import org.apache.iceberg.actions.RewritePositionDeleteFiles.Result; +import org.apache.iceberg.actions.SizeBasedFileRewriter; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.deletes.PositionDeleteWriter; +import org.apache.iceberg.encryption.EncryptedFiles; +import org.apache.iceberg.encryption.EncryptedOutputFile; +import org.apache.iceberg.encryption.EncryptionKeyMetadata; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestRewritePositionDeleteFiles extends ExtensionsTestBase { + + private static final Map CATALOG_PROPS = + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "cache-enabled", "false"); + + private static final String PARTITION_COL = "partition_col"; + private static final int NUM_DATA_FILES = 5; + private static final int ROWS_PER_DATA_FILE = 100; + private static final int DELETE_FILES_PER_PARTITION = 2; + private static final int DELETE_FILE_SIZE = 10; + + @Parameters(name = "formatVersion = {0}, catalogName = {1}, implementation = {2}, config = {3}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + CATALOG_PROPS + } + }; + } + + @AfterEach + public void cleanup() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testDatePartition() throws Exception { + createTable("date"); + Date baseDate = Date.valueOf("2023-01-01"); + insertData(i -> 
Date.valueOf(baseDate.toLocalDate().plusDays(i))); + testDanglingDelete(); + } + + @TestTemplate + public void testBooleanPartition() throws Exception { + createTable("boolean"); + insertData(i -> i % 2 == 0, 2); + testDanglingDelete(2); + } + + @TestTemplate + public void testTimestampPartition() throws Exception { + createTable("timestamp"); + Timestamp baseTimestamp = Timestamp.valueOf("2023-01-01 15:30:00"); + insertData(i -> Timestamp.valueOf(baseTimestamp.toLocalDateTime().plusDays(i))); + testDanglingDelete(); + } + + @TestTemplate + public void testTimestampNtz() throws Exception { + createTable("timestamp_ntz"); + LocalDateTime baseTimestamp = Timestamp.valueOf("2023-01-01 15:30:00").toLocalDateTime(); + insertData(baseTimestamp::plusDays); + testDanglingDelete(); + } + + @TestTemplate + public void testBytePartition() throws Exception { + createTable("byte"); + insertData(i -> i); + testDanglingDelete(); + } + + @TestTemplate + public void testDecimalPartition() throws Exception { + createTable("decimal(18, 10)"); + BigDecimal baseDecimal = new BigDecimal("1.0"); + insertData(i -> baseDecimal.add(new BigDecimal(i))); + testDanglingDelete(); + } + + @TestTemplate + public void testBinaryPartition() throws Exception { + createTable("binary"); + insertData(i -> java.nio.ByteBuffer.allocate(4).putInt(i).array()); + testDanglingDelete(); + } + + @TestTemplate + public void testCharPartition() throws Exception { + createTable("char(10)"); + insertData(Object::toString); + testDanglingDelete(); + } + + @TestTemplate + public void testVarcharPartition() throws Exception { + createTable("varchar(10)"); + insertData(Object::toString); + testDanglingDelete(); + } + + @TestTemplate + public void testIntPartition() throws Exception { + createTable("int"); + insertData(i -> i); + testDanglingDelete(); + } + + @TestTemplate + public void testDaysPartitionTransform() throws Exception { + createTable("timestamp", PARTITION_COL, String.format("days(%s)", PARTITION_COL)); + Timestamp baseTimestamp = Timestamp.valueOf("2023-01-01 15:30:00"); + insertData(i -> Timestamp.valueOf(baseTimestamp.toLocalDateTime().plusDays(i))); + testDanglingDelete(); + } + + @TestTemplate + public void testNullTransform() throws Exception { + createTable("int"); + insertData(i -> i == 0 ? 
null : 1, 2); + testDanglingDelete(2); + } + + @TestTemplate + public void testPartitionColWithDot() throws Exception { + String partitionColWithDot = "`partition.col`"; + createTable("int", partitionColWithDot, partitionColWithDot); + insertData(partitionColWithDot, i -> i, NUM_DATA_FILES); + testDanglingDelete(partitionColWithDot, NUM_DATA_FILES); + } + + private void testDanglingDelete() throws Exception { + testDanglingDelete(NUM_DATA_FILES); + } + + private void testDanglingDelete(int numDataFiles) throws Exception { + testDanglingDelete(PARTITION_COL, numDataFiles); + } + + private void testDanglingDelete(String partitionCol, int numDataFiles) throws Exception { + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + List dataFiles = dataFiles(table); + assertThat(dataFiles).hasSize(numDataFiles); + + SparkActions.get(spark) + .rewriteDataFiles(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + + // write dangling delete files for 'old data files' + writePosDeletesForFiles(table, dataFiles); + List deleteFiles = deleteFiles(table); + assertThat(deleteFiles).hasSize(numDataFiles * DELETE_FILES_PER_PARTITION); + + List expectedRecords = records(tableName, partitionCol); + + Result result = + SparkActions.get(spark) + .rewritePositionDeletes(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + + List newDeleteFiles = deleteFiles(table); + assertThat(newDeleteFiles).as("Remaining dangling deletes").isEmpty(); + checkResult(result, deleteFiles, Lists.newArrayList(), numDataFiles); + + List actualRecords = records(tableName, partitionCol); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + private void createTable(String partitionType) { + createTable(partitionType, PARTITION_COL, PARTITION_COL); + } + + private void createTable(String partitionType, String partitionCol, String partitionTransform) { + sql( + "CREATE TABLE %s (id long, %s %s, c1 string, c2 string) " + + "USING iceberg " + + "PARTITIONED BY (%s) " + + "TBLPROPERTIES('format-version'='2')", + tableName, partitionCol, partitionType, partitionTransform); + } + + private void insertData(Function partitionValueFunction) throws Exception { + insertData(partitionValueFunction, NUM_DATA_FILES); + } + + private void insertData(Function partitionValueFunction, int numDataFiles) + throws Exception { + insertData(PARTITION_COL, partitionValueFunction, numDataFiles); + } + + private void insertData( + String partitionCol, Function partitionValue, int numDataFiles) throws Exception { + for (int i = 0; i < numDataFiles; i++) { + Dataset df = + spark + .range(0, ROWS_PER_DATA_FILE) + .withColumn(partitionCol, lit(partitionValue.apply(i))) + .withColumn("c1", expr("CAST(id AS STRING)")) + .withColumn("c2", expr("CAST(id AS STRING)")); + appendAsFile(df); + } + } + + private void appendAsFile(Dataset df) throws Exception { + // ensure the schema is precise + StructType sparkSchema = spark.table(tableName).schema(); + spark.createDataFrame(df.rdd(), sparkSchema).coalesce(1).writeTo(tableName).append(); + } + + private void writePosDeletesForFiles(Table table, List files) throws IOException { + + Map> filesByPartition = + files.stream().collect(Collectors.groupingBy(ContentFile::partition)); + List deleteFiles = + Lists.newArrayListWithCapacity(DELETE_FILES_PER_PARTITION * filesByPartition.size()); + + for (Map.Entry> filesByPartitionEntry : + filesByPartition.entrySet()) { + + StructLike partition = filesByPartitionEntry.getKey(); + List partitionFiles = 
filesByPartitionEntry.getValue(); + + int deletesForPartition = partitionFiles.size() * DELETE_FILE_SIZE; + assertThat(deletesForPartition % DELETE_FILE_SIZE) + .as("Number of delete files per partition modulo number of data files in this partition") + .isEqualTo(0); + int deleteFileSize = deletesForPartition / DELETE_FILES_PER_PARTITION; + + int counter = 0; + List> deletes = Lists.newArrayList(); + for (DataFile partitionFile : partitionFiles) { + for (int deletePos = 0; deletePos < DELETE_FILE_SIZE; deletePos++) { + deletes.add(Pair.of(partitionFile.path(), (long) deletePos)); + counter++; + if (counter == deleteFileSize) { + // Dump to file and reset variables + OutputFile output = + Files.localOutput(temp.resolve(UUID.randomUUID().toString()).toFile()); + deleteFiles.add(writeDeleteFile(table, output, partition, deletes)); + counter = 0; + deletes.clear(); + } + } + } + } + + RowDelta rowDelta = table.newRowDelta(); + deleteFiles.forEach(rowDelta::addDeletes); + rowDelta.commit(); + } + + private DeleteFile writeDeleteFile( + Table table, OutputFile out, StructLike partition, List> deletes) + throws IOException { + FileFormat format = defaultFormat(table.properties()); + FileAppenderFactory factory = new GenericAppenderFactory(table.schema(), table.spec()); + + PositionDeleteWriter writer = + factory.newPosDeleteWriter(encrypt(out), format, partition); + PositionDelete posDelete = PositionDelete.create(); + try (Closeable toClose = writer) { + for (Pair delete : deletes) { + writer.write(posDelete.set(delete.first(), delete.second(), null)); + } + } + + return writer.toDeleteFile(); + } + + private static EncryptedOutputFile encrypt(OutputFile out) { + return EncryptedFiles.encryptedOutput(out, EncryptionKeyMetadata.EMPTY); + } + + private static FileFormat defaultFormat(Map properties) { + String formatString = properties.getOrDefault(DEFAULT_FILE_FORMAT, DEFAULT_FILE_FORMAT_DEFAULT); + return FileFormat.fromString(formatString); + } + + private List records(String table, String partitionCol) { + return rowsToJava( + spark.read().format("iceberg").load(table).sort(partitionCol, "id").collectAsList()); + } + + private long size(List deleteFiles) { + return deleteFiles.stream().mapToLong(DeleteFile::fileSizeInBytes).sum(); + } + + private List dataFiles(Table table) { + CloseableIterable tasks = table.newScan().includeColumnStats().planFiles(); + return Lists.newArrayList(CloseableIterable.transform(tasks, FileScanTask::file)); + } + + private List deleteFiles(Table table) { + Table deletesTable = + MetadataTableUtils.createMetadataTableInstance(table, MetadataTableType.POSITION_DELETES); + CloseableIterable tasks = deletesTable.newBatchScan().planFiles(); + return Lists.newArrayList( + CloseableIterable.transform(tasks, t -> ((PositionDeletesScanTask) t).file())); + } + + private void checkResult( + Result result, + List rewrittenDeletes, + List newDeletes, + int expectedGroups) { + assertThat(result.rewrittenDeleteFilesCount()) + .as("Rewritten delete files") + .isEqualTo(rewrittenDeletes.size()); + assertThat(result.addedDeleteFilesCount()) + .as("Added delete files") + .isEqualTo(newDeletes.size()); + assertThat(result.rewrittenBytesCount()) + .as("Rewritten delete bytes") + .isEqualTo(size(rewrittenDeletes)); + assertThat(result.addedBytesCount()).as("New Delete byte count").isEqualTo(size(newDeletes)); + + assertThat(result.rewriteResults()).as("Rewritten group count").hasSize(expectedGroups); + assertThat( + result.rewriteResults().stream() + 
.mapToInt(FileGroupRewriteResult::rewrittenDeleteFilesCount) + .sum()) + .as("Rewritten delete file count in all groups") + .isEqualTo(rewrittenDeletes.size()); + assertThat( + result.rewriteResults().stream() + .mapToInt(FileGroupRewriteResult::addedDeleteFilesCount) + .sum()) + .as("Added delete file count in all groups") + .isEqualTo(newDeletes.size()); + assertThat( + result.rewriteResults().stream() + .mapToLong(FileGroupRewriteResult::rewrittenBytesCount) + .sum()) + .as("Rewritten delete bytes in all groups") + .isEqualTo(size(rewrittenDeletes)); + assertThat( + result.rewriteResults().stream() + .mapToLong(FileGroupRewriteResult::addedBytesCount) + .sum()) + .as("Added delete bytes in all groups") + .isEqualTo(size(newDeletes)); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFilesProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFilesProcedure.java new file mode 100644 index 000000000000..bb82b63d208d --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFilesProcedure.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.SnapshotSummary.ADDED_FILE_SIZE_PROP; +import static org.apache.iceberg.SnapshotSummary.REMOVED_FILE_SIZE_PROP; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.EnvironmentContext; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Encoders; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestRewritePositionDeleteFilesProcedure extends ExtensionsTestBase { + + private void createTable() throws Exception { + createTable(false); + } + + private void createTable(boolean partitioned) throws Exception { + String partitionStmt = partitioned ? 
"PARTITIONED BY (id)" : ""; + sql( + "CREATE TABLE %s (id bigint, data string) USING iceberg %s TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName, partitionStmt); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(1, "b"), + new SimpleRecord(1, "c"), + new SimpleRecord(2, "d"), + new SimpleRecord(2, "e"), + new SimpleRecord(2, "f"), + new SimpleRecord(3, "g"), + new SimpleRecord(3, "h"), + new SimpleRecord(3, "i"), + new SimpleRecord(4, "j"), + new SimpleRecord(4, "k"), + new SimpleRecord(4, "l"), + new SimpleRecord(5, "m"), + new SimpleRecord(5, "n"), + new SimpleRecord(5, "o"), + new SimpleRecord(6, "p"), + new SimpleRecord(6, "q"), + new SimpleRecord(6, "r")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testExpireDeleteFilesAll() throws Exception { + createTable(); + + sql("DELETE FROM %s WHERE id=1", tableName); + sql("DELETE FROM %s WHERE id=2", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(TestHelpers.deleteFiles(table)).hasSize(2); + + List output = + sql( + "CALL %s.system.rewrite_position_delete_files(" + + "table => '%s'," + + "options => map(" + + "'rewrite-all','true'))", + catalogName, tableIdent); + table.refresh(); + + Map snapshotSummary = snapshotSummary(); + assertEquals( + "Should delete 2 delete files and add 1", + ImmutableList.of( + row( + 2, + 1, + Long.valueOf(snapshotSummary.get(REMOVED_FILE_SIZE_PROP)), + Long.valueOf(snapshotSummary.get(ADDED_FILE_SIZE_PROP)))), + output); + + assertThat(TestHelpers.deleteFiles(table)).hasSize(1); + } + + @TestTemplate + public void testExpireDeleteFilesNoOption() throws Exception { + createTable(); + + sql("DELETE FROM %s WHERE id=1", tableName); + sql("DELETE FROM %s WHERE id=2", tableName); + sql("DELETE FROM %s WHERE id=3", tableName); + sql("DELETE FROM %s WHERE id=4", tableName); + sql("DELETE FROM %s WHERE id=5", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(TestHelpers.deleteFiles(table)).hasSize(5); + + List output = + sql( + "CALL %s.system.rewrite_position_delete_files(" + "table => '%s')", + catalogName, tableIdent); + table.refresh(); + + Map snapshotSummary = snapshotSummary(); + assertEquals( + "Should replace 5 delete files with 1", + ImmutableList.of( + row( + 5, + 1, + Long.valueOf(snapshotSummary.get(REMOVED_FILE_SIZE_PROP)), + Long.valueOf(snapshotSummary.get(ADDED_FILE_SIZE_PROP)))), + output); + } + + @TestTemplate + public void testExpireDeleteFilesFilter() throws Exception { + createTable(true); + + sql("DELETE FROM %s WHERE id = 1 and data='a'", tableName); + sql("DELETE FROM %s WHERE id = 1 and data='b'", tableName); + sql("DELETE FROM %s WHERE id = 2 and data='d'", tableName); + sql("DELETE FROM %s WHERE id = 2 and data='e'", tableName); + sql("DELETE FROM %s WHERE id = 3 and data='g'", tableName); + sql("DELETE FROM %s WHERE id = 3 and data='h'", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(TestHelpers.deleteFiles(table)).hasSize(6); + + List output = + sql( + "CALL %s.system.rewrite_position_delete_files(" + + "table => '%s'," + // data filter is ignored as it cannot be applied to position deletes + + "where => 'id IN (1, 2) AND data=\"bar\"'," + + "options => map(" + + "'rewrite-all','true'))", + 
catalogName, tableIdent); + table.refresh(); + + Map snapshotSummary = snapshotSummary(); + assertEquals( + "Should delete 4 delete files and add 2", + ImmutableList.of( + row( + 4, + 2, + Long.valueOf(snapshotSummary.get(REMOVED_FILE_SIZE_PROP)), + Long.valueOf(snapshotSummary.get(ADDED_FILE_SIZE_PROP)))), + output); + + assertThat(TestHelpers.deleteFiles(table)).hasSize(4); + } + + @TestTemplate + public void testInvalidOption() throws Exception { + createTable(); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_position_delete_files(" + + "table => '%s'," + + "options => map(" + + "'foo', 'bar'))", + catalogName, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining( + "Cannot use options [foo], they are not supported by the action or the rewriter BIN-PACK"); + } + + @TestTemplate + public void testRewriteWithUntranslatedOrUnconvertedFilter() throws Exception { + createTable(); + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_position_delete_files(table => '%s', where => 'substr(encode(data, \"utf-8\"), 2) = \"fo\"')", + catalogName, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Cannot translate Spark expression"); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rewrite_position_delete_files(table => '%s', where => 'substr(data, 2) = \"fo\"')", + catalogName, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Cannot convert Spark filter"); + } + + @TestTemplate + public void testRewriteSummary() throws Exception { + createTable(); + sql("DELETE FROM %s WHERE id=1", tableName); + + sql( + "CALL %s.system.rewrite_position_delete_files(" + + "table => '%s'," + + "options => map(" + + "'rewrite-all','true'))", + catalogName, tableIdent); + + Map summary = snapshotSummary(); + assertThat(summary) + .containsKey(CatalogProperties.APP_ID) + .containsEntry(EnvironmentContext.ENGINE_NAME, "spark") + .hasEntrySatisfying( + EnvironmentContext.ENGINE_VERSION, v -> assertThat(v).startsWith("3.5")); + } + + private Map snapshotSummary() { + return validationCatalog.loadTable(tableIdent).currentSnapshot().summary(); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToSnapshotProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToSnapshotProcedure.java new file mode 100644 index 000000000000..43df78bf766d --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToSnapshotProcedure.java @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.util.List; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestRollbackToSnapshotProcedure extends ExtensionsTestBase { + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testRollbackToSnapshotUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.rollback_to_snapshot('%s', %dL)", + catalogName, tableIdent, firstSnapshot.snapshotId()); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testRollbackToSnapshotUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.rollback_to_snapshot(snapshot_id => %dL, table => '%s')", + catalogName, firstSnapshot.snapshotId(), tableIdent); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testRollbackToSnapshotRefreshesRelationCache() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", 
tableName); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + Dataset query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals( + "View should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM tmp")); + + List output = + sql( + "CALL %s.system.rollback_to_snapshot(table => '%s', snapshot_id => %dL)", + catalogName, tableIdent, firstSnapshot.snapshotId()); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "View cache must be invalidated", ImmutableList.of(row(1L, "a")), sql("SELECT * FROM tmp")); + + sql("UNCACHE TABLE tmp"); + } + + @TestTemplate + public void testRollbackToSnapshotWithQuotedIdentifiers() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + StringBuilder quotedNamespaceBuilder = new StringBuilder(); + for (String level : tableIdent.namespace().levels()) { + quotedNamespaceBuilder.append("`"); + quotedNamespaceBuilder.append(level); + quotedNamespaceBuilder.append("`"); + } + String quotedNamespace = quotedNamespaceBuilder.toString(); + + List output = + sql( + "CALL %s.system.rollback_to_snapshot('%s', %d)", + catalogName, + quotedNamespace + ".`" + tableIdent.name() + "`", + firstSnapshot.snapshotId()); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testRollbackToSnapshotWithoutExplicitCatalog() { + assumeThat(catalogName).as("Working only with the session catalog").isEqualTo("spark_catalog"); + + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + // use camel case intentionally to test case sensitivity + List output = + sql("CALL SyStEm.rOLlBaCk_to_SnApShOt('%s', %dL)", tableIdent, firstSnapshot.snapshotId()); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testRollbackToInvalidSnapshot() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + assertThatThrownBy( + () 
-> sql("CALL %s.system.rollback_to_snapshot('%s', -1L)", catalogName, tableIdent)) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot roll back to unknown snapshot id: -1"); + } + + @TestTemplate + public void testInvalidRollbackToSnapshotCases() { + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rollback_to_snapshot(namespace => 'n1', table => 't', 1L)", + catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Named and positional arguments cannot be mixed"); + + assertThatThrownBy(() -> sql("CALL %s.custom.rollback_to_snapshot('n', 't', 1L)", catalogName)) + .isInstanceOf(ParseException.class) + .satisfies( + exception -> { + ParseException parseException = (ParseException) exception; + assertThat(parseException.getErrorClass()).isEqualTo("PARSE_SYNTAX_ERROR"); + assertThat(parseException.getMessageParameters().get("error")).isEqualTo("'CALL'"); + }); + + assertThatThrownBy(() -> sql("CALL %s.system.rollback_to_snapshot('t')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [snapshot_id]"); + + assertThatThrownBy(() -> sql("CALL %s.system.rollback_to_snapshot(1L)", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [snapshot_id]"); + + assertThatThrownBy(() -> sql("CALL %s.system.rollback_to_snapshot(table => 't')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [snapshot_id]"); + + assertThatThrownBy(() -> sql("CALL %s.system.rollback_to_snapshot('t', 2.2)", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Wrong arg type for snapshot_id: cannot cast DecimalType(2,1) to LongType"); + + assertThatThrownBy(() -> sql("CALL %s.system.rollback_to_snapshot('', 1L)", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot handle an empty identifier for argument table"); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToTimestampProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToTimestampProcedure.java new file mode 100644 index 000000000000..ae35b9f1817c --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToTimestampProcedure.java @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDateTime; +import java.util.List; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestRollbackToTimestampProcedure extends ExtensionsTestBase { + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testRollbackToTimestampUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.rollback_to_timestamp('%s',TIMESTAMP '%s')", + catalogName, tableIdent, firstSnapshotTimestamp); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testRollbackToTimestampUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.rollback_to_timestamp(timestamp => TIMESTAMP '%s', table => '%s')", + catalogName, firstSnapshotTimestamp, tableIdent); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testRollbackToTimestampRefreshesRelationCache() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = 
validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + Dataset query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals( + "View should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM tmp")); + + List output = + sql( + "CALL %s.system.rollback_to_timestamp(table => '%s', timestamp => TIMESTAMP '%s')", + catalogName, tableIdent, firstSnapshotTimestamp); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "View cache must be invalidated", ImmutableList.of(row(1L, "a")), sql("SELECT * FROM tmp")); + + sql("UNCACHE TABLE tmp"); + } + + @TestTemplate + public void testRollbackToTimestampWithQuotedIdentifiers() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + StringBuilder quotedNamespaceBuilder = new StringBuilder(); + for (String level : tableIdent.namespace().levels()) { + quotedNamespaceBuilder.append("`"); + quotedNamespaceBuilder.append(level); + quotedNamespaceBuilder.append("`"); + } + String quotedNamespace = quotedNamespaceBuilder.toString(); + + List output = + sql( + "CALL %s.system.rollback_to_timestamp('%s', TIMESTAMP '%s')", + catalogName, quotedNamespace + ".`" + tableIdent.name() + "`", firstSnapshotTimestamp); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testRollbackToTimestampWithoutExplicitCatalog() { + assumeThat(catalogName).as("Working only with the session catalog").isEqualTo("spark_catalog"); + + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + // use camel case intentionally to test case sensitivity + List output = + sql( + "CALL 
SyStEm.rOLlBaCk_to_TiMeStaMp('%s', TIMESTAMP '%s')", + tableIdent, firstSnapshotTimestamp); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testRollbackToTimestampBeforeOrEqualToOldestSnapshot() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + Timestamp beforeFirstSnapshot = + Timestamp.from(Instant.ofEpochMilli(firstSnapshot.timestampMillis() - 1)); + Timestamp exactFirstSnapshot = + Timestamp.from(Instant.ofEpochMilli(firstSnapshot.timestampMillis())); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rollback_to_timestamp(timestamp => TIMESTAMP '%s', table => '%s')", + catalogName, beforeFirstSnapshot, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Cannot roll back, no valid snapshot older than: %s", + beforeFirstSnapshot.toInstant().toEpochMilli()); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rollback_to_timestamp(timestamp => TIMESTAMP '%s', table => '%s')", + catalogName, exactFirstSnapshot, tableIdent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Cannot roll back, no valid snapshot older than: %s", + exactFirstSnapshot.toInstant().toEpochMilli()); + } + + @TestTemplate + public void testInvalidRollbackToTimestampCases() { + String timestamp = "TIMESTAMP '2007-12-03T10:15:30'"; + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rollback_to_timestamp(namespace => 'n1', 't', %s)", + catalogName, timestamp)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Named and positional arguments cannot be mixed"); + + assertThatThrownBy( + () -> sql("CALL %s.custom.rollback_to_timestamp('n', 't', %s)", catalogName, timestamp)) + .isInstanceOf(ParseException.class) + .satisfies( + exception -> { + ParseException parseException = (ParseException) exception; + assertThat(parseException.getErrorClass()).isEqualTo("PARSE_SYNTAX_ERROR"); + assertThat(parseException.getMessageParameters().get("error")).isEqualTo("'CALL'"); + }); + + assertThatThrownBy(() -> sql("CALL %s.system.rollback_to_timestamp('t')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [timestamp]"); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rollback_to_timestamp(timestamp => %s)", + catalogName, timestamp)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [table]"); + + assertThatThrownBy(() -> sql("CALL %s.system.rollback_to_timestamp(table => 't')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [timestamp]"); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.rollback_to_timestamp('n', 't', %s, 1L)", + catalogName, timestamp)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Too many arguments for procedure"); + + assertThatThrownBy(() -> sql("CALL %s.system.rollback_to_timestamp('t', 2.2)", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Wrong arg type for timestamp: cannot cast DecimalType(2,1) to TimestampType"); + } +} diff --git 
a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java new file mode 100644 index 000000000000..4c34edef5d25 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.util.List; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSetCurrentSnapshotProcedure extends ExtensionsTestBase { + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testSetCurrentSnapshotUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.set_current_snapshot('%s', %dL)", + catalogName, tableIdent, firstSnapshot.snapshotId()); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Set must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testSetCurrentSnapshotUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT 
NULL, data string) USING iceberg", tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Snapshot firstSnapshot = table.currentSnapshot();
+
+    sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+    assertEquals(
+        "Should have expected rows",
+        ImmutableList.of(row(1L, "a"), row(1L, "a")),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+
+    table.refresh();
+
+    Snapshot secondSnapshot = table.currentSnapshot();
+
+    List<Object[]> output =
+        sql(
+            "CALL %s.system.set_current_snapshot(snapshot_id => %dL, table => '%s')",
+            catalogName, firstSnapshot.snapshotId(), tableIdent);
+
+    assertEquals(
+        "Procedure output must match",
+        ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+        output);
+
+    assertEquals(
+        "Set must be successful",
+        ImmutableList.of(row(1L, "a")),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
+
+  @TestTemplate
+  public void testSetCurrentSnapshotWap() {
+    sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+    sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED);
+
+    spark.conf().set("spark.wap.id", "1");
+
+    sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+    assertEquals(
+        "Should not see rows from staged snapshot",
+        ImmutableList.of(),
+        sql("SELECT * FROM %s", tableName));
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots());
+
+    List<Object[]> output =
+        sql(
+            "CALL %s.system.set_current_snapshot(table => '%s', snapshot_id => %dL)",
+            catalogName, tableIdent, wapSnapshot.snapshotId());
+
+    assertEquals(
+        "Procedure output must match",
+        ImmutableList.of(row(null, wapSnapshot.snapshotId())),
+        output);
+
+    assertEquals(
+        "Current snapshot must be set correctly",
+        ImmutableList.of(row(1L, "a")),
+        sql("SELECT * FROM %s", tableName));
+  }
+
+  @TestTemplate
+  public void testSetCurrentSnapshotWithoutExplicitCatalog() {
+    assumeThat(catalogName).as("Working only with the session catalog").isEqualTo("spark_catalog");
+
+    sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Snapshot firstSnapshot = table.currentSnapshot();
+
+    sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+    assertEquals(
+        "Should have expected rows",
+        ImmutableList.of(row(1L, "a"), row(1L, "a")),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+
+    table.refresh();
+
+    Snapshot secondSnapshot = table.currentSnapshot();
+
+    // use camel case intentionally to test case sensitivity
+    List<Object[]> output =
+        sql("CALL SyStEm.sEt_cuRrEnT_sNaPsHot('%s', %dL)", tableIdent, firstSnapshot.snapshotId());
+
+    assertEquals(
+        "Procedure output must match",
+        ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+        output);
+
+    assertEquals(
+        "Set must be successful",
+        ImmutableList.of(row(1L, "a")),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
+
+  @TestTemplate
+  public void testSetCurrentSnapshotToInvalidSnapshot() {
+    sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+
+    assertThatThrownBy(
+            () -> sql("CALL %s.system.set_current_snapshot('%s', -1L)", catalogName, tableIdent))
+        .isInstanceOf(ValidationException.class)
+        .hasMessage("Cannot roll back to unknown snapshot id: -1");
+  }
+
+  @TestTemplate
+  public void 
testInvalidRollbackToSnapshotCases() { + assertThatThrownBy( + () -> + sql( + "CALL %s.system.set_current_snapshot(namespace => 'n1', table => 't', 1L)", + catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Named and positional arguments cannot be mixed"); + + assertThatThrownBy(() -> sql("CALL %s.custom.set_current_snapshot('n', 't', 1L)", catalogName)) + .isInstanceOf(ParseException.class) + .satisfies( + exception -> { + ParseException parseException = (ParseException) exception; + assertThat(parseException.getErrorClass()).isEqualTo("PARSE_SYNTAX_ERROR"); + assertThat(parseException.getMessageParameters().get("error")).isEqualTo("'CALL'"); + }); + + assertThatThrownBy(() -> sql("CALL %s.system.set_current_snapshot('t')", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Either snapshot_id or ref must be provided, not both"); + + assertThatThrownBy(() -> sql("CALL %s.system.set_current_snapshot(1L)", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot parse identifier for arg table: 1"); + + assertThatThrownBy( + () -> sql("CALL %s.system.set_current_snapshot(snapshot_id => 1L)", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [table]"); + + assertThatThrownBy(() -> sql("CALL %s.system.set_current_snapshot(table => 't')", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Either snapshot_id or ref must be provided, not both"); + + assertThatThrownBy(() -> sql("CALL %s.system.set_current_snapshot('t', 2.2)", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Wrong arg type for snapshot_id: cannot cast DecimalType(2,1) to LongType"); + + assertThatThrownBy(() -> sql("CALL %s.system.set_current_snapshot('', 1L)", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot handle an empty identifier for argument table"); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.set_current_snapshot(table => 't', snapshot_id => 1L, ref => 's1')", + catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Either snapshot_id or ref must be provided, not both"); + } + + @TestTemplate + public void testSetCurrentSnapshotToRef() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String ref = "s1"; + sql("ALTER TABLE %s CREATE TAG %s", tableName, ref); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')", + catalogName, tableIdent, ref); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Set must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + String notExistRef = "s2"; + assertThatThrownBy( + () -> + sql( + "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')", + catalogName, tableIdent, notExistRef)) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot find matching 
snapshot ID for ref " + notExistRef); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetWriteDistributionAndOrdering.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetWriteDistributionAndOrdering.java new file mode 100644 index 000000000000..b8547772da67 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetWriteDistributionAndOrdering.java @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.expressions.Expressions.bucket; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSetWriteDistributionAndOrdering extends ExtensionsTestBase { + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testSetWriteOrderByColumn() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).as("Table should start unsorted").isTrue(); + + sql("ALTER TABLE %s WRITE ORDERED BY category, id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + assertThat(distributionMode).as("Distribution mode must match").isEqualTo("range"); + + SortOrder expected = + SortOrder.builderFor(table.schema()) + .withOrderId(1) + .asc("category", NullOrder.NULLS_FIRST) + .asc("id", NullOrder.NULLS_FIRST) + .build(); + assertThat(table.sortOrder()).as("Should have expected order").isEqualTo(expected); + } + + @TestTemplate + public void testSetWriteOrderWithCaseSensitiveColumnNames() { + sql( + "CREATE TABLE %s (Id bigint NOT NULL, Category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).as("Table should start unsorted").isTrue(); + sql("SET %s=true", SQLConf.CASE_SENSITIVE().key()); + assertThatThrownBy( + () -> { + sql("ALTER 
TABLE %s WRITE ORDERED BY category, id", tableName); + }) + .isInstanceOf(ValidationException.class) + .hasMessageContaining("Cannot find field 'category' in struct"); + + sql("SET %s=false", SQLConf.CASE_SENSITIVE().key()); + sql("ALTER TABLE %s WRITE ORDERED BY category, id", tableName); + table = validationCatalog.loadTable(tableIdent); + SortOrder expected = + SortOrder.builderFor(table.schema()).withOrderId(1).asc("Category").asc("Id").build(); + assertThat(table.sortOrder()).as("Should have expected order").isEqualTo(expected); + } + + @TestTemplate + public void testSetWriteOrderByColumnWithDirection() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).as("Table should start unsorted").isTrue(); + + sql("ALTER TABLE %s WRITE ORDERED BY category ASC, id DESC", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + assertThat(distributionMode).as("Distribution mode must match").isEqualTo("range"); + + SortOrder expected = + SortOrder.builderFor(table.schema()) + .withOrderId(1) + .asc("category", NullOrder.NULLS_FIRST) + .desc("id", NullOrder.NULLS_LAST) + .build(); + assertThat(table.sortOrder()).as("Should have expected order").isEqualTo(expected); + } + + @TestTemplate + public void testSetWriteOrderByColumnWithDirectionAndNullOrder() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).as("Table should start unsorted").isTrue(); + + sql("ALTER TABLE %s WRITE ORDERED BY category ASC NULLS LAST, id DESC NULLS FIRST", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + assertThat(distributionMode).as("Distribution mode must match").isEqualTo("range"); + + SortOrder expected = + SortOrder.builderFor(table.schema()) + .withOrderId(1) + .asc("category", NullOrder.NULLS_LAST) + .desc("id", NullOrder.NULLS_FIRST) + .build(); + assertThat(table.sortOrder()).as("Should have expected order").isEqualTo(expected); + } + + @TestTemplate + public void testSetWriteOrderByTransform() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).isTrue(); + + sql("ALTER TABLE %s WRITE ORDERED BY category DESC, bucket(16, id), id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + assertThat(distributionMode).as("Distribution mode must match").isEqualTo("range"); + + SortOrder expected = + SortOrder.builderFor(table.schema()) + .withOrderId(1) + .desc("category") + .asc(bucket("id", 16)) + .asc("id") + .build(); + assertThat(table.sortOrder()).as("Should have expected order").isEqualTo(expected); + } + + @TestTemplate + public void testSetWriteUnordered() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).as("Table should start unsorted").isTrue(); + + sql("ALTER TABLE %s WRITE ORDERED BY category 
DESC, bucket(16, id), id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + assertThat(distributionMode).as("Distribution mode must match").isEqualTo("range"); + + assertThat(table.sortOrder()).as("Table must be sorted").isNotEqualTo(SortOrder.unsorted()); + + sql("ALTER TABLE %s WRITE UNORDERED", tableName); + + table.refresh(); + + String newDistributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + assertThat(newDistributionMode).as("New distribution mode must match").isEqualTo("none"); + + assertThat(table.sortOrder()).as("New sort order must match").isEqualTo(SortOrder.unsorted()); + } + + @TestTemplate + public void testSetWriteLocallyOrdered() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).as("Table should start unsorted").isTrue(); + + sql("ALTER TABLE %s WRITE LOCALLY ORDERED BY category DESC, bucket(16, id), id", tableName); + + table.refresh(); + + assertThat(table.properties().containsKey(TableProperties.WRITE_DISTRIBUTION_MODE)).isFalse(); + + SortOrder expected = + SortOrder.builderFor(table.schema()) + .withOrderId(1) + .desc("category") + .asc(bucket("id", 16)) + .asc("id") + .build(); + assertThat(table.sortOrder()).as("Sort order must match").isEqualTo(expected); + } + + @TestTemplate + public void testSetWriteLocallyOrderedToPartitionedTable() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (id)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).as("Table should start unsorted").isTrue(); + + sql("ALTER TABLE %s WRITE LOCALLY ORDERED BY category DESC", tableName); + + table.refresh(); + + assertThat(table.properties().containsKey(TableProperties.WRITE_DISTRIBUTION_MODE)).isFalse(); + + SortOrder expected = + SortOrder.builderFor(table.schema()).withOrderId(1).desc("category").build(); + assertThat(table.sortOrder()).as("Sort order must match").isEqualTo(expected); + } + + @TestTemplate + public void testSetWriteDistributedByWithSort() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).as("Table should start unsorted").isTrue(); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + assertThat(distributionMode).as("Distribution mode must match").isEqualTo("hash"); + + SortOrder expected = SortOrder.builderFor(table.schema()).withOrderId(1).asc("id").build(); + assertThat(table.sortOrder()).as("Sort order must match").isEqualTo(expected); + } + + @TestTemplate + public void testSetWriteDistributedByWithLocalSort() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).as("Table should start unsorted").isTrue(); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION LOCALLY ORDERED BY id", tableName); + + table.refresh(); + + String distributionMode = 
table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + assertThat(distributionMode).as("Distribution mode must match").isEqualTo("hash"); + + SortOrder expected = SortOrder.builderFor(table.schema()).withOrderId(1).asc("id").build(); + assertThat(table.sortOrder()).as("Sort order must match").isEqualTo(expected); + + sql("ALTER TABLE %s WRITE LOCALLY ORDERED BY id", tableName); + + table.refresh(); + + String newDistributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + assertThat(newDistributionMode).as("Distribution mode must match").isEqualTo(distributionMode); + } + + @TestTemplate + public void testSetWriteDistributedByAndUnordered() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).as("Table should start unsorted").isTrue(); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION UNORDERED", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + assertThat(distributionMode).as("Distribution mode must match").isEqualTo("hash"); + + assertThat(table.sortOrder()).as("Sort order must match").isEqualTo(SortOrder.unsorted()); + } + + @TestTemplate + public void testSetWriteDistributedByOnly() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).as("Table should start unsorted").isTrue(); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION UNORDERED", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + assertThat(distributionMode).as("Distribution mode must match").isEqualTo("hash"); + + assertThat(table.sortOrder()).as("Sort order must match").isEqualTo(SortOrder.unsorted()); + } + + @TestTemplate + public void testSetWriteDistributedAndUnorderedInverted() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).as("Table should start unsorted").isTrue(); + + sql("ALTER TABLE %s WRITE UNORDERED DISTRIBUTED BY PARTITION", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + assertThat(distributionMode).as("Distribution mode must match").isEqualTo("hash"); + + assertThat(table.sortOrder()).as("Sort order must match").isEqualTo(SortOrder.unsorted()); + } + + @TestTemplate + public void testSetWriteDistributedAndLocallyOrderedInverted() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.sortOrder().isUnsorted()).as("Table should start unsorted").isTrue(); + + sql("ALTER TABLE %s WRITE ORDERED BY id DISTRIBUTED BY PARTITION", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + assertThat(distributionMode).as("Distribution mode must match").isEqualTo("hash"); + + SortOrder expected = SortOrder.builderFor(table.schema()).withOrderId(1).asc("id").build(); + assertThat(table.sortOrder()).as("Sort order must 
match").isEqualTo(expected); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java new file mode 100644 index 000000000000..6caff28bb16c --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.IOException; +import java.nio.file.Files; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSnapshotTableProcedure extends ExtensionsTestBase { + private static final String SOURCE_NAME = "spark_catalog.default.source"; + + // Currently we can only Snapshot only out of the Spark Session Catalog + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s PURGE", SOURCE_NAME); + } + + @TestTemplate + public void testSnapshot() throws IOException { + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + SOURCE_NAME, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", SOURCE_NAME); + Object result = + scalarSql("CALL %s.system.snapshot('%s', '%s')", catalogName, SOURCE_NAME, tableName); + + assertThat(result).as("Should have added one file").isEqualTo(1L); + + Table createdTable = validationCatalog.loadTable(tableIdent); + String tableLocation = createdTable.location(); + assertThat(tableLocation) + .as("Table should not have the original location") + .isNotEqualTo(location); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testSnapshotWithProperties() throws IOException { + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + 
sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + SOURCE_NAME, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", SOURCE_NAME); + Object result = + scalarSql( + "CALL %s.system.snapshot(source_table => '%s', table => '%s', properties => map('foo','bar'))", + catalogName, SOURCE_NAME, tableName); + + assertThat(result).as("Should have added one file").isEqualTo(1L); + + Table createdTable = validationCatalog.loadTable(tableIdent); + + String tableLocation = createdTable.location(); + assertThat(tableLocation) + .as("Table should not have the original location") + .isNotEqualTo(location); + + Map props = createdTable.properties(); + assertThat(props).as("Should have extra property set").containsEntry("foo", "bar"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testSnapshotWithAlternateLocation() throws IOException { + assumeThat(catalogName) + .as("No Snapshoting with Alternate locations with Hadoop Catalogs") + .doesNotContain("hadoop"); + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + String snapshotLocation = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + SOURCE_NAME, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", SOURCE_NAME); + Object[] result = + sql( + "CALL %s.system.snapshot(source_table => '%s', table => '%s', location => '%s')", + catalogName, SOURCE_NAME, tableName, snapshotLocation) + .get(0); + + assertThat(result[0]).as("Should have added one file").isEqualTo(1L); + + String storageLocation = validationCatalog.loadTable(tableIdent).location(); + assertThat(storageLocation) + .as("Snapshot should be made at specified location") + .isEqualTo(snapshotLocation); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testDropTable() throws IOException { + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + SOURCE_NAME, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", SOURCE_NAME); + + Object result = + scalarSql("CALL %s.system.snapshot('%s', '%s')", catalogName, SOURCE_NAME, tableName); + assertThat(result).as("Should have added one file").isEqualTo(1L); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + + sql("DROP TABLE %s", tableName); + + assertEquals( + "Source table should be intact", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", SOURCE_NAME)); + } + + @TestTemplate + public void testSnapshotWithConflictingProps() throws IOException { + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + SOURCE_NAME, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", SOURCE_NAME); + + Object result = + scalarSql( + "CALL %s.system.snapshot(" + + "source_table => '%s'," + + "table => '%s'," + + "properties => map('%s', 'true', 'snapshot', 'false'))", + catalogName, SOURCE_NAME, 
tableName, TableProperties.GC_ENABLED); + assertThat(result).as("Should have added one file").isEqualTo(1L); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Map props = table.properties(); + assertThat(props).as("Should override user value").containsEntry("snapshot", "true"); + assertThat(props) + .as("Should override user value") + .containsEntry(TableProperties.GC_ENABLED, "false"); + } + + @TestTemplate + public void testInvalidSnapshotsCases() throws IOException { + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + SOURCE_NAME, location); + + assertThatThrownBy(() -> sql("CALL %s.system.snapshot('foo')", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessage("Missing required parameters: [table]"); + + assertThatThrownBy( + () -> sql("CALL %s.system.snapshot('n', 't', map('foo', 'bar'))", catalogName)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith("Wrong arg type for location"); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.snapshot('%s', 'fable', 'loc', map(2, 1, 1))", + catalogName, SOURCE_NAME)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "The `map` requires 2n (n > 0) parameters but the actual number is 3"); + + assertThatThrownBy(() -> sql("CALL %s.system.snapshot('', 'dest')", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot handle an empty identifier for argument source_table"); + + assertThatThrownBy(() -> sql("CALL %s.system.snapshot('src', '')", catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot handle an empty identifier for argument table"); + } + + @TestTemplate + public void testSnapshotWithParallelism() throws IOException { + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + SOURCE_NAME, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", SOURCE_NAME); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", SOURCE_NAME); + + List result = + sql( + "CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)", + catalogName, SOURCE_NAME, tableName, 2); + assertEquals("Procedure output must match", ImmutableList.of(row(2L)), result); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testSnapshotWithInvalidParallelism() throws IOException { + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + SOURCE_NAME, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", SOURCE_NAME); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", SOURCE_NAME); + + assertThatThrownBy( + () -> + sql( + "CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)", + catalogName, SOURCE_NAME, tableName, -1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Parallelism should be larger than 0"); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestStoragePartitionedJoinsInRowLevelOperations.java 
b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestStoragePartitionedJoinsInRowLevelOperations.java new file mode 100644 index 000000000000..ce609450c097 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestStoragePartitionedJoinsInRowLevelOperations.java @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE; +import static org.apache.iceberg.RowLevelOperationMode.MERGE_ON_READ; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Map; +import org.apache.commons.lang3.StringUtils; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.spark.sql.execution.SparkPlan; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestStoragePartitionedJoinsInRowLevelOperations extends ExtensionsTestBase { + + private static final String OTHER_TABLE_NAME = "other_table"; + + // open file cost and split size are set as 16 MB to produce a split per file + private static final Map COMMON_TABLE_PROPERTIES = + ImmutableMap.of( + TableProperties.FORMAT_VERSION, + "2", + TableProperties.SPLIT_SIZE, + "16777216", + TableProperties.SPLIT_OPEN_FILE_COST, + "16777216"); + + // only v2 bucketing and preserve data grouping properties have to be enabled to trigger SPJ + // other properties are only to simplify testing and validation + private static final Map ENABLED_SPJ_SQL_CONF = + ImmutableMap.of( + SQLConf.V2_BUCKETING_ENABLED().key(), + "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED().key(), + "true", + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION().key(), + "false", + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), + "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD().key(), + "-1", + SparkSQLProperties.PRESERVE_DATA_GROUPING, + "true"); + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties() + } + }; 
+ } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", tableName(OTHER_TABLE_NAME)); + } + + @TestTemplate + public void testCopyOnWriteDeleteWithoutShuffles() { + checkDelete(COPY_ON_WRITE); + } + + @TestTemplate + public void testMergeOnReadDeleteWithoutShuffles() { + checkDelete(MERGE_ON_READ); + } + + private void checkDelete(RowLevelOperationMode mode) { + String createTableStmt = + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep) " + + "TBLPROPERTIES (%s)"; + + sql(createTableStmt, tableName, tablePropsAsString(COMMON_TABLE_PROPERTIES)); + + append(tableName, "{ \"id\": 1, \"salary\": 100, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 2, \"salary\": 200, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 3, \"salary\": 300, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 4, \"salary\": 400, \"dep\": \"hardware\" }"); + + sql(createTableStmt, tableName(OTHER_TABLE_NAME), tablePropsAsString(COMMON_TABLE_PROPERTIES)); + + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 1, \"salary\": 110, \"dep\": \"hr\" }"); + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 5, \"salary\": 500, \"dep\": \"hr\" }"); + + Map deleteTableProps = + ImmutableMap.of( + TableProperties.DELETE_MODE, + mode.modeName(), + TableProperties.DELETE_DISTRIBUTION_MODE, + "none"); + + sql("ALTER TABLE %s SET TBLPROPERTIES(%s)", tableName, tablePropsAsString(deleteTableProps)); + + withSQLConf( + ENABLED_SPJ_SQL_CONF, + () -> { + SparkPlan plan = + executeAndKeepPlan( + "DELETE FROM %s t WHERE " + + "EXISTS (SELECT 1 FROM %s s WHERE t.id = s.id AND t.dep = s.dep)", + tableName, tableName(OTHER_TABLE_NAME)); + String planAsString = plan.toString(); + if (mode == COPY_ON_WRITE) { + int actualNumShuffles = StringUtils.countMatches(planAsString, "Exchange"); + assertThat(actualNumShuffles).as("Should be 1 shuffle with SPJ").isEqualTo(1); + assertThat(planAsString).contains("Exchange hashpartitioning(_file"); + } else { + assertThat(planAsString).doesNotContain("Exchange"); + } + }); + + ImmutableList expectedRows = + ImmutableList.of( + row(2, 200, "hr"), // remaining + row(3, 300, "hr"), // remaining + row(4, 400, "hardware")); // remaining + + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id, salary", tableName)); + } + + @TestTemplate + public void testCopyOnWriteUpdateWithoutShuffles() { + checkUpdate(COPY_ON_WRITE); + } + + @TestTemplate + public void testMergeOnReadUpdateWithoutShuffles() { + checkUpdate(MERGE_ON_READ); + } + + private void checkUpdate(RowLevelOperationMode mode) { + String createTableStmt = + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep) " + + "TBLPROPERTIES (%s)"; + + sql(createTableStmt, tableName, tablePropsAsString(COMMON_TABLE_PROPERTIES)); + + append(tableName, "{ \"id\": 1, \"salary\": 100, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 2, \"salary\": 200, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 3, \"salary\": 300, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 4, \"salary\": 400, \"dep\": \"hardware\" }"); + + sql(createTableStmt, tableName(OTHER_TABLE_NAME), tablePropsAsString(COMMON_TABLE_PROPERTIES)); + + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 1, \"salary\": 110, \"dep\": \"hr\" }"); + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 5, \"salary\": 500, \"dep\": \"hr\" }"); + + Map updateTableProps = + ImmutableMap.of( + 
TableProperties.UPDATE_MODE, + mode.modeName(), + TableProperties.UPDATE_DISTRIBUTION_MODE, + "none"); + + sql("ALTER TABLE %s SET TBLPROPERTIES(%s)", tableName, tablePropsAsString(updateTableProps)); + + withSQLConf( + ENABLED_SPJ_SQL_CONF, + () -> { + SparkPlan plan = + executeAndKeepPlan( + "UPDATE %s t SET salary = -1 WHERE " + + "EXISTS (SELECT 1 FROM %s s WHERE t.id = s.id AND t.dep = s.dep)", + tableName, tableName(OTHER_TABLE_NAME)); + String planAsString = plan.toString(); + if (mode == COPY_ON_WRITE) { + int actualNumShuffles = StringUtils.countMatches(planAsString, "Exchange"); + assertThat(actualNumShuffles).as("Should be 1 shuffle with SPJ").isEqualTo(1); + assertThat(planAsString).contains("Exchange hashpartitioning(_file"); + } else { + assertThat(planAsString).doesNotContain("Exchange"); + } + }); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, -1, "hr"), // updated + row(2, 200, "hr"), // existing + row(3, 300, "hr"), // existing + row(4, 400, "hardware")); // existing + + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id, salary", tableName)); + } + + @TestTemplate + public void testCopyOnWriteMergeWithoutShuffles() { + checkMerge(COPY_ON_WRITE, false /* with ON predicate */); + } + + @TestTemplate + public void testCopyOnWriteMergeWithoutShufflesWithPredicate() { + checkMerge(COPY_ON_WRITE, true /* with ON predicate */); + } + + @TestTemplate + public void testMergeOnReadMergeWithoutShuffles() { + checkMerge(MERGE_ON_READ, false /* with ON predicate */); + } + + @TestTemplate + public void testMergeOnReadMergeWithoutShufflesWithPredicate() { + checkMerge(MERGE_ON_READ, true /* with ON predicate */); + } + + private void checkMerge(RowLevelOperationMode mode, boolean withPredicate) { + String createTableStmt = + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep) " + + "TBLPROPERTIES (%s)"; + + sql(createTableStmt, tableName, tablePropsAsString(COMMON_TABLE_PROPERTIES)); + + append(tableName, "{ \"id\": 1, \"salary\": 100, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 2, \"salary\": 200, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 3, \"salary\": 300, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 4, \"salary\": 400, \"dep\": \"hardware\" }"); + append(tableName, "{ \"id\": 6, \"salary\": 600, \"dep\": \"software\" }"); + + sql(createTableStmt, tableName(OTHER_TABLE_NAME), tablePropsAsString(COMMON_TABLE_PROPERTIES)); + + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 1, \"salary\": 110, \"dep\": \"hr\" }"); + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 5, \"salary\": 500, \"dep\": \"hr\" }"); + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 6, \"salary\": 300, \"dep\": \"software\" }"); + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 10, \"salary\": 1000, \"dep\": \"ops\" }"); + + Map mergeTableProps = + ImmutableMap.of( + TableProperties.MERGE_MODE, + mode.modeName(), + TableProperties.MERGE_DISTRIBUTION_MODE, + "none"); + + sql("ALTER TABLE %s SET TBLPROPERTIES(%s)", tableName, tablePropsAsString(mergeTableProps)); + + withSQLConf( + ENABLED_SPJ_SQL_CONF, + () -> { + String predicate = withPredicate ? 
"AND t.dep IN ('hr', 'ops', 'software')" : ""; + SparkPlan plan = + executeAndKeepPlan( + "MERGE INTO %s AS t USING %s AS s " + + "ON t.id = s.id AND t.dep = s.dep %s " + + "WHEN MATCHED THEN " + + " UPDATE SET t.salary = s.salary " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + tableName, tableName(OTHER_TABLE_NAME), predicate); + String planAsString = plan.toString(); + if (mode == COPY_ON_WRITE) { + int actualNumShuffles = StringUtils.countMatches(planAsString, "Exchange"); + assertThat(actualNumShuffles).as("Should be 1 shuffle with SPJ").isEqualTo(1); + assertThat(planAsString).contains("Exchange hashpartitioning(_file"); + } else { + assertThat(planAsString).doesNotContain("Exchange"); + } + }); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, 110, "hr"), // updated + row(2, 200, "hr"), // existing + row(3, 300, "hr"), // existing + row(4, 400, "hardware"), // existing + row(5, 500, "hr"), // new + row(6, 300, "software"), // updated + row(10, 1000, "ops")); // new + + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id, salary", tableName)); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java new file mode 100644 index 000000000000..f6102bab69b0 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.expressions.Expressions.bucket; +import static org.apache.iceberg.expressions.Expressions.day; +import static org.apache.iceberg.expressions.Expressions.equal; +import static org.apache.iceberg.expressions.Expressions.greaterThan; +import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.hour; +import static org.apache.iceberg.expressions.Expressions.lessThan; +import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.month; +import static org.apache.iceberg.expressions.Expressions.notEqual; +import static org.apache.iceberg.expressions.Expressions.truncate; +import static org.apache.iceberg.expressions.Expressions.year; +import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.STRUCT; +import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.createPartitionedTable; +import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.createUnpartitionedTable; +import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToDayOrdinal; +import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToHourOrdinal; +import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToMonthOrdinal; +import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToYearOrdinal; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.expressions.ExpressionUtil; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.source.PlanUtils; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSystemFunctionPushDownDQL extends ExtensionsTestBase { + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties(), + }, + }; + } + + @BeforeEach + public void before() { + super.before(); + sql("USE %s", catalogName); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testYearsFunctionOnUnpartitionedTable() { + createUnpartitionedTable(spark, tableName); + testYearsFunction(false); + } + + @TestTemplate + public void testYearsFunctionOnPartitionedTable() { + createPartitionedTable(spark, tableName, "years(ts)"); + testYearsFunction(true); + } + + private void testYearsFunction(boolean partitioned) { + int targetYears = timestampStrToYearOrdinal("2017-11-22T00:00:00.000000+00:00"); + String query = + String.format( + "SELECT * FROM %s WHERE system.years(ts) = %s ORDER BY id", tableName, targetYears); 
+ + Dataset df = spark.sql(query); + LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan(); + + checkExpressions(optimizedPlan, partitioned, "years"); + checkPushedFilters(optimizedPlan, equal(year("ts"), targetYears)); + + List actual = rowsToJava(df.collectAsList()); + assertThat(actual).hasSize(5); + } + + @TestTemplate + public void testMonthsFunctionOnUnpartitionedTable() { + createUnpartitionedTable(spark, tableName); + testMonthsFunction(false); + } + + @TestTemplate + public void testMonthsFunctionOnPartitionedTable() { + createPartitionedTable(spark, tableName, "months(ts)"); + testMonthsFunction(true); + } + + private void testMonthsFunction(boolean partitioned) { + int targetMonths = timestampStrToMonthOrdinal("2017-11-22T00:00:00.000000+00:00"); + String query = + String.format( + "SELECT * FROM %s WHERE system.months(ts) > %s ORDER BY id", tableName, targetMonths); + + Dataset df = spark.sql(query); + LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan(); + + checkExpressions(optimizedPlan, partitioned, "months"); + checkPushedFilters(optimizedPlan, greaterThan(month("ts"), targetMonths)); + + List actual = rowsToJava(df.collectAsList()); + assertThat(actual).hasSize(5); + } + + @TestTemplate + public void testDaysFunctionOnUnpartitionedTable() { + createUnpartitionedTable(spark, tableName); + testDaysFunction(false); + } + + @TestTemplate + public void testDaysFunctionOnPartitionedTable() { + createPartitionedTable(spark, tableName, "days(ts)"); + testDaysFunction(true); + } + + private void testDaysFunction(boolean partitioned) { + String timestamp = "2018-11-20T00:00:00.000000+00:00"; + int targetDays = timestampStrToDayOrdinal(timestamp); + String query = + String.format( + "SELECT * FROM %s WHERE system.days(ts) < date('%s') ORDER BY id", + tableName, timestamp); + + Dataset df = spark.sql(query); + LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan(); + + checkExpressions(optimizedPlan, partitioned, "days"); + checkPushedFilters(optimizedPlan, lessThan(day("ts"), targetDays)); + + List actual = rowsToJava(df.collectAsList()); + assertThat(actual).hasSize(5); + } + + @TestTemplate + public void testHoursFunctionOnUnpartitionedTable() { + createUnpartitionedTable(spark, tableName); + testHoursFunction(false); + } + + @TestTemplate + public void testHoursFunctionOnPartitionedTable() { + createPartitionedTable(spark, tableName, "hours(ts)"); + testHoursFunction(true); + } + + private void testHoursFunction(boolean partitioned) { + int targetHours = timestampStrToHourOrdinal("2017-11-22T06:02:09.243857+00:00"); + String query = + String.format( + "SELECT * FROM %s WHERE system.hours(ts) >= %s ORDER BY id", tableName, targetHours); + + Dataset df = spark.sql(query); + LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan(); + + checkExpressions(optimizedPlan, partitioned, "hours"); + checkPushedFilters(optimizedPlan, greaterThanOrEqual(hour("ts"), targetHours)); + + List actual = rowsToJava(df.collectAsList()); + assertThat(actual).hasSize(8); + } + + @TestTemplate + public void testBucketLongFunctionOnUnpartitionedTable() { + createUnpartitionedTable(spark, tableName); + testBucketLongFunction(false); + } + + @TestTemplate + public void testBucketLongFunctionOnPartitionedTable() { + createPartitionedTable(spark, tableName, "bucket(5, id)"); + testBucketLongFunction(true); + } + + private void testBucketLongFunction(boolean partitioned) { + int target = 2; + String query = + String.format( + "SELECT * FROM %s WHERE system.bucket(5, id) <= 
%s ORDER BY id", tableName, target); + + Dataset df = spark.sql(query); + LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan(); + + checkExpressions(optimizedPlan, partitioned, "bucket"); + checkPushedFilters(optimizedPlan, lessThanOrEqual(bucket("id", 5), target)); + + List actual = rowsToJava(df.collectAsList()); + assertThat(actual).hasSize(5); + } + + @TestTemplate + public void testBucketStringFunctionOnUnpartitionedTable() { + createUnpartitionedTable(spark, tableName); + testBucketStringFunction(false); + } + + @TestTemplate + public void testBucketStringFunctionOnPartitionedTable() { + createPartitionedTable(spark, tableName, "bucket(5, data)"); + testBucketStringFunction(true); + } + + private void testBucketStringFunction(boolean partitioned) { + int target = 2; + String query = + String.format( + "SELECT * FROM %s WHERE system.bucket(5, data) != %s ORDER BY id", tableName, target); + + Dataset df = spark.sql(query); + LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan(); + + checkExpressions(optimizedPlan, partitioned, "bucket"); + checkPushedFilters(optimizedPlan, notEqual(bucket("data", 5), target)); + + List actual = rowsToJava(df.collectAsList()); + assertThat(actual).hasSize(8); + } + + @TestTemplate + public void testTruncateFunctionOnUnpartitionedTable() { + createUnpartitionedTable(spark, tableName); + testTruncateFunction(false); + } + + @TestTemplate + public void testTruncateFunctionOnPartitionedTable() { + createPartitionedTable(spark, tableName, "truncate(4, data)"); + testTruncateFunction(true); + } + + private void testTruncateFunction(boolean partitioned) { + String target = "data"; + String query = + String.format( + "SELECT * FROM %s WHERE system.truncate(4, data) = '%s' ORDER BY id", + tableName, target); + + Dataset df = spark.sql(query); + LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan(); + + checkExpressions(optimizedPlan, partitioned, "truncate"); + checkPushedFilters(optimizedPlan, equal(truncate("data", 4), target)); + + List actual = rowsToJava(df.collectAsList()); + assertThat(actual).hasSize(5); + } + + private void checkExpressions( + LogicalPlan optimizedPlan, boolean partitioned, String expectedFunctionName) { + List staticInvokes = + PlanUtils.collectSparkExpressions( + optimizedPlan, expression -> expression instanceof StaticInvoke); + assertThat(staticInvokes).isEmpty(); + + List applyExpressions = + PlanUtils.collectSparkExpressions( + optimizedPlan, expression -> expression instanceof ApplyFunctionExpression); + + if (partitioned) { + assertThat(applyExpressions).isEmpty(); + } else { + assertThat(applyExpressions).hasSize(1); + ApplyFunctionExpression expression = (ApplyFunctionExpression) applyExpressions.get(0); + assertThat(expression.name()).isEqualTo(expectedFunctionName); + } + } + + private void checkPushedFilters( + LogicalPlan optimizedPlan, org.apache.iceberg.expressions.Expression expected) { + List pushedFilters = + PlanUtils.collectPushDownFilters(optimizedPlan); + assertThat(pushedFilters).hasSize(1); + org.apache.iceberg.expressions.Expression actual = pushedFilters.get(0); + assertThat(ExpressionUtil.equivalent(expected, actual, STRUCT, true)) + .as("Pushed filter should match") + .isTrue(); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownInRowLevelOperations.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownInRowLevelOperations.java new file mode 100644 
index 000000000000..934220e5d31e --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownInRowLevelOperations.java @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE; +import static org.apache.iceberg.RowLevelOperationMode.MERGE_ON_READ; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke; +import org.apache.spark.sql.execution.CommandResultExec; +import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSystemFunctionPushDownInRowLevelOperations extends ExtensionsTestBase { + + private static final String CHANGES_TABLE_NAME = "changes"; + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties() + } + }; + } + + @BeforeEach + public void beforeEach() { + sql("USE %s", catalogName); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s PURGE", tableName); + sql("DROP TABLE IF EXISTS %s PURGE", tableName(CHANGES_TABLE_NAME)); + } + + @TestTemplate + public void testCopyOnWriteDeleteBucketTransformInPredicate() { + initTable("bucket(4, dep)"); + checkDelete(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)"); + } + + @TestTemplate + public void testMergeOnReadDeleteBucketTransformInPredicate() { + initTable("bucket(4, dep)"); + checkDelete(MERGE_ON_READ, 
"system.bucket(4, dep) IN (2, 3)"); + } + + @TestTemplate + public void testCopyOnWriteDeleteBucketTransformEqPredicate() { + initTable("bucket(4, dep)"); + checkDelete(COPY_ON_WRITE, "system.bucket(4, dep) = 2"); + } + + @TestTemplate + public void testMergeOnReadDeleteBucketTransformEqPredicate() { + initTable("bucket(4, dep)"); + checkDelete(MERGE_ON_READ, "system.bucket(4, dep) = 2"); + } + + @TestTemplate + public void testCopyOnWriteDeleteYearsTransform() { + initTable("years(ts)"); + checkDelete(COPY_ON_WRITE, "system.years(ts) > 30"); + } + + @TestTemplate + public void testMergeOnReadDeleteYearsTransform() { + initTable("years(ts)"); + checkDelete(MERGE_ON_READ, "system.years(ts) <= 30"); + } + + @TestTemplate + public void testCopyOnWriteDeleteMonthsTransform() { + initTable("months(ts)"); + checkDelete(COPY_ON_WRITE, "system.months(ts) <= 250"); + } + + @TestTemplate + public void testMergeOnReadDeleteMonthsTransform() { + initTable("months(ts)"); + checkDelete(MERGE_ON_READ, "system.months(ts) > 250"); + } + + @TestTemplate + public void testCopyOnWriteDeleteDaysTransform() { + initTable("days(ts)"); + checkDelete(COPY_ON_WRITE, "system.days(ts) <= date('2000-01-03 00:00:00')"); + } + + @TestTemplate + public void testMergeOnReadDeleteDaysTransform() { + initTable("days(ts)"); + checkDelete(MERGE_ON_READ, "system.days(ts) > date('2000-01-03 00:00:00')"); + } + + @TestTemplate + public void testCopyOnWriteDeleteHoursTransform() { + initTable("hours(ts)"); + checkDelete(COPY_ON_WRITE, "system.hours(ts) <= 100000"); + } + + @TestTemplate + public void testMergeOnReadDeleteHoursTransform() { + initTable("hours(ts)"); + checkDelete(MERGE_ON_READ, "system.hours(ts) > 100000"); + } + + @TestTemplate + public void testCopyOnWriteDeleteTruncateTransform() { + initTable("truncate(1, dep)"); + checkDelete(COPY_ON_WRITE, "system.truncate(1, dep) = 'i'"); + } + + @TestTemplate + public void testMergeOnReadDeleteTruncateTransform() { + initTable("truncate(1, dep)"); + checkDelete(MERGE_ON_READ, "system.truncate(1, dep) = 'i'"); + } + + @TestTemplate + public void testCopyOnWriteUpdateBucketTransform() { + initTable("bucket(4, dep)"); + checkUpdate(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)"); + } + + @TestTemplate + public void testMergeOnReadUpdateBucketTransform() { + initTable("bucket(4, dep)"); + checkUpdate(MERGE_ON_READ, "system.bucket(4, dep) = 2"); + } + + @TestTemplate + public void testCopyOnWriteUpdateYearsTransform() { + initTable("years(ts)"); + checkUpdate(COPY_ON_WRITE, "system.years(ts) > 30"); + } + + @TestTemplate + public void testMergeOnReadUpdateYearsTransform() { + initTable("years(ts)"); + checkUpdate(MERGE_ON_READ, "system.years(ts) <= 30"); + } + + @TestTemplate + public void testCopyOnWriteMergeBucketTransform() { + initTable("bucket(4, dep)"); + checkMerge(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)"); + } + + @TestTemplate + public void testMergeOnReadMergeBucketTransform() { + initTable("bucket(4, dep)"); + checkMerge(MERGE_ON_READ, "system.bucket(4, dep) = 2"); + } + + @TestTemplate + public void testCopyOnWriteMergeYearsTransform() { + initTable("years(ts)"); + checkMerge(COPY_ON_WRITE, "system.years(ts) > 30"); + } + + @TestTemplate + public void testMergeOnReadMergeYearsTransform() { + initTable("years(ts)"); + checkMerge(MERGE_ON_READ, "system.years(ts) <= 30"); + } + + @TestTemplate + public void testCopyOnWriteMergeTruncateTransform() { + initTable("truncate(1, dep)"); + checkMerge(COPY_ON_WRITE, "system.truncate(1, dep) = 'i'"); + } + + 
@TestTemplate + public void testMergeOnReadMergeTruncateTransform() { + initTable("truncate(1, dep)"); + checkMerge(MERGE_ON_READ, "system.truncate(1, dep) = 'i'"); + } + + private void checkDelete(RowLevelOperationMode mode, String cond) { + withUnavailableLocations( + findIrrelevantFileLocations(cond), + () -> { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')", + tableName, + TableProperties.DELETE_MODE, + mode.modeName(), + TableProperties.DELETE_DISTRIBUTION_MODE, + DistributionMode.NONE.modeName()); + + Dataset changeDF = spark.table(tableName).where(cond).limit(2).select("id"); + try { + changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create(); + } catch (TableAlreadyExistsException e) { + throw new AlreadyExistsException( + "Cannot create table %s as it already exists", CHANGES_TABLE_NAME); + } + + List calls = + executeAndCollectFunctionCalls( + "DELETE FROM %s t WHERE %s AND t.id IN (SELECT id FROM %s)", + tableName, cond, tableName(CHANGES_TABLE_NAME)); + // CoW planning currently does not optimize post-scan filters in DELETE + int expectedCallCount = mode == COPY_ON_WRITE ? 1 : 0; + assertThat(calls).hasSize(expectedCallCount); + + assertEquals( + "Should have no matching rows", + ImmutableList.of(), + sql( + "SELECT * FROM %s WHERE %s AND id IN (SELECT * FROM %s)", + tableName, cond, tableName(CHANGES_TABLE_NAME))); + }); + } + + private void checkUpdate(RowLevelOperationMode mode, String cond) { + withUnavailableLocations( + findIrrelevantFileLocations(cond), + () -> { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')", + tableName, + TableProperties.UPDATE_MODE, + mode.modeName(), + TableProperties.UPDATE_DISTRIBUTION_MODE, + DistributionMode.NONE.modeName()); + + Dataset changeDF = spark.table(tableName).where(cond).limit(2).select("id"); + try { + changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create(); + } catch (TableAlreadyExistsException e) { + throw new AlreadyExistsException( + "Cannot create table %s as it already exists", CHANGES_TABLE_NAME); + } + + List calls = + executeAndCollectFunctionCalls( + "UPDATE %s t SET t.salary = -1 WHERE %s AND t.id IN (SELECT id FROM %s)", + tableName, cond, tableName(CHANGES_TABLE_NAME)); + // CoW planning currently does not optimize post-scan filters in UPDATE + int expectedCallCount = mode == COPY_ON_WRITE ? 
2 : 0; + assertThat(calls).hasSize(expectedCallCount); + + assertEquals( + "Should have correct updates", + sql("SELECT id FROM %s", tableName(CHANGES_TABLE_NAME)), + sql("SELECT id FROM %s WHERE %s AND salary = -1", tableName, cond)); + }); + } + + private void checkMerge(RowLevelOperationMode mode, String cond) { + withUnavailableLocations( + findIrrelevantFileLocations(cond), + () -> { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')", + tableName, + TableProperties.MERGE_MODE, + mode.modeName(), + TableProperties.MERGE_DISTRIBUTION_MODE, + DistributionMode.NONE.modeName()); + + Dataset changeDF = + spark.table(tableName).where(cond).limit(2).selectExpr("id + 1 as id"); + try { + changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create(); + } catch (TableAlreadyExistsException e) { + throw new AlreadyExistsException( + "Cannot create table %s as it already exists", CHANGES_TABLE_NAME); + } + + List calls = + executeAndCollectFunctionCalls( + "MERGE INTO %s t USING %s s " + + "ON t.id == s.id AND %s " + + "WHEN MATCHED THEN " + + " UPDATE SET salary = -1 " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT (id, salary, dep, ts) VALUES (100, -1, 'hr', null)", + tableName, tableName(CHANGES_TABLE_NAME), cond); + assertThat(calls).isEmpty(); + + assertEquals( + "Should have correct updates", + sql("SELECT id FROM %s", tableName(CHANGES_TABLE_NAME)), + sql("SELECT id FROM %s WHERE %s AND salary = -1", tableName, cond)); + }); + } + + private List executeAndCollectFunctionCalls(String query, Object... args) { + CommandResultExec command = (CommandResultExec) executeAndKeepPlan(query, args); + V2TableWriteExec write = (V2TableWriteExec) command.commandPhysicalPlan(); + return SparkPlanUtil.collectExprs( + write.query(), + expr -> expr instanceof StaticInvoke || expr instanceof ApplyFunctionExpression); + } + + private List findIrrelevantFileLocations(String cond) { + return spark + .table(tableName) + .where("NOT " + cond) + .select(MetadataColumns.FILE_PATH.name()) + .distinct() + .as(Encoders.STRING()) + .collectAsList(); + } + + private void initTable(String transform) { + sql( + "CREATE TABLE %s (id BIGINT, salary INT, dep STRING, ts TIMESTAMP)" + + "USING iceberg " + + "PARTITIONED BY (%s)", + tableName, transform); + + append( + tableName, + "{ \"id\": 1, \"salary\": 100, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }", + "{ \"id\": 2, \"salary\": 200, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }", + "{ \"id\": 3, \"salary\": 300, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }", + "{ \"id\": 4, \"salary\": 400, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }", + "{ \"id\": 5, \"salary\": 500, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }", + "{ \"id\": 6, \"salary\": 600, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }"); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestTagDDL.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestTagDDL.java new file mode 100644 index 000000000000..65c2c0f713cb --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestTagDDL.java @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import java.util.Locale; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.extensions.IcebergParseException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestTagDDL extends ExtensionsTestBase { + private static final String[] TIME_UNITS = {"DAYS", "HOURS", "MINUTES"}; + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties() + } + }; + } + + @BeforeEach + public void createTable() { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + } + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testCreateTagWithRetain() throws NoSuchTableException { + Table table = insertRows(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + long maxRefAge = 10L; + + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + for (String timeUnit : TIME_UNITS) { + String tagName = "t1" + timeUnit; + sql( + "ALTER TABLE %s CREATE TAG %s AS OF VERSION %d RETAIN %d %s", + tableName, tagName, firstSnapshotId, maxRefAge, timeUnit); + table.refresh(); + SnapshotRef ref = table.refs().get(tagName); + assertThat(ref.snapshotId()) + .as("The tag needs to point to a specific snapshot id.") + .isEqualTo(firstSnapshotId); + assertThat(ref.maxRefAgeMs().longValue()) + .as("The tag needs to have the correct max ref age.") + .isEqualTo(TimeUnit.valueOf(timeUnit.toUpperCase(Locale.ENGLISH)).toMillis(maxRefAge)); + } + + String tagName = "t1"; + assertThatThrownBy( + () -> + sql( + "ALTER TABLE %s CREATE TAG %s AS OF VERSION %d RETAIN", + tableName, tagName, firstSnapshotId, maxRefAge)) + 
.isInstanceOf(IcebergParseException.class) + .hasMessageContaining("mismatched input"); + + assertThatThrownBy( + () -> sql("ALTER TABLE %s CREATE TAG %s RETAIN %s DAYS", tableName, tagName, "abc")) + .isInstanceOf(IcebergParseException.class) + .hasMessageContaining("mismatched input"); + + assertThatThrownBy( + () -> + sql( + "ALTER TABLE %s CREATE TAG %s AS OF VERSION %d RETAIN %d SECONDS", + tableName, tagName, firstSnapshotId, maxRefAge)) + .isInstanceOf(IcebergParseException.class) + .hasMessageContaining("mismatched input 'SECONDS' expecting {'DAYS', 'HOURS', 'MINUTES'}"); + } + + @TestTemplate + public void testCreateTagOnEmptyTable() { + assertThatThrownBy(() -> sql("ALTER TABLE %s CREATE TAG %s", tableName, "abc")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining( + "Cannot complete create or replace tag operation on %s, main has no snapshot", + tableName); + } + + @TestTemplate + public void testCreateTagUseDefaultConfig() throws NoSuchTableException { + Table table = insertRows(); + long snapshotId = table.currentSnapshot().snapshotId(); + String tagName = "t1"; + + assertThatThrownBy( + () -> sql("ALTER TABLE %s CREATE TAG %s AS OF VERSION %d", tableName, tagName, -1)) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot set " + tagName + " to unknown snapshot: -1"); + + sql("ALTER TABLE %s CREATE TAG %s", tableName, tagName); + table.refresh(); + SnapshotRef ref = table.refs().get(tagName); + assertThat(ref.snapshotId()) + .as("The tag needs to point to a specific snapshot id.") + .isEqualTo(snapshotId); + assertThat(ref.maxRefAgeMs()) + .as("The tag needs to have the default max ref age, which is null.") + .isNull(); + + assertThatThrownBy(() -> sql("ALTER TABLE %s CREATE TAG %s", tableName, tagName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("already exists"); + + assertThatThrownBy(() -> sql("ALTER TABLE %s CREATE TAG %s", tableName, "123")) + .isInstanceOf(IcebergParseException.class) + .hasMessageContaining("mismatched input '123'"); + + table.manageSnapshots().removeTag(tagName).commit(); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + snapshotId = table.currentSnapshot().snapshotId(); + sql("ALTER TABLE %s CREATE TAG %s AS OF VERSION %d", tableName, tagName, snapshotId); + table.refresh(); + ref = table.refs().get(tagName); + assertThat(ref.snapshotId()) + .as("The tag needs to point to a specific snapshot id.") + .isEqualTo(snapshotId); + assertThat(ref.maxRefAgeMs()) + .as("The tag needs to have the default max ref age, which is null.") + .isNull(); + } + + @TestTemplate + public void testCreateTagIfNotExists() throws NoSuchTableException { + long maxSnapshotAge = 2L; + Table table = insertRows(); + String tagName = "t1"; + sql("ALTER TABLE %s CREATE TAG %s RETAIN %d days", tableName, tagName, maxSnapshotAge); + sql("ALTER TABLE %s CREATE TAG IF NOT EXISTS %s", tableName, tagName); + + table.refresh(); + SnapshotRef ref = table.refs().get(tagName); + assertThat(ref.snapshotId()) + .as("The tag needs to point to a specific snapshot id.") + .isEqualTo(table.currentSnapshot().snapshotId()); + assertThat(ref.maxRefAgeMs().longValue()) + .as("The tag needs to have the correct max ref age.") + .isEqualTo(TimeUnit.DAYS.toMillis(maxSnapshotAge)); + } + + @TestTemplate + public void testReplaceTagFailsForBranch() throws NoSuchTableException { + String branchName = 
"branch1"; + Table table = insertRows(); + long first = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createBranch(branchName, first).commit(); + insertRows(); + long second = table.currentSnapshot().snapshotId(); + + assertThatThrownBy(() -> sql("ALTER TABLE %s REPLACE Tag %s", tableName, branchName, second)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Ref branch1 is a branch not a tag"); + } + + @TestTemplate + public void testReplaceTag() throws NoSuchTableException { + Table table = insertRows(); + long first = table.currentSnapshot().snapshotId(); + String tagName = "t1"; + long expectedMaxRefAgeMs = 1000; + table + .manageSnapshots() + .createTag(tagName, first) + .setMaxRefAgeMs(tagName, expectedMaxRefAgeMs) + .commit(); + + insertRows(); + long second = table.currentSnapshot().snapshotId(); + + sql("ALTER TABLE %s REPLACE Tag %s AS OF VERSION %d", tableName, tagName, second); + table.refresh(); + SnapshotRef ref = table.refs().get(tagName); + assertThat(ref.snapshotId()) + .as("The tag needs to point to a specific snapshot id.") + .isEqualTo(second); + assertThat(ref.maxRefAgeMs().longValue()) + .as("The tag needs to have the correct max ref age.") + .isEqualTo(expectedMaxRefAgeMs); + } + + @TestTemplate + public void testReplaceTagDoesNotExist() throws NoSuchTableException { + Table table = insertRows(); + + assertThatThrownBy( + () -> + sql( + "ALTER TABLE %s REPLACE Tag %s AS OF VERSION %d", + tableName, "someTag", table.currentSnapshot().snapshotId())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Tag does not exist"); + } + + @TestTemplate + public void testReplaceTagWithRetain() throws NoSuchTableException { + Table table = insertRows(); + long first = table.currentSnapshot().snapshotId(); + String tagName = "t1"; + table.manageSnapshots().createTag(tagName, first).commit(); + insertRows(); + long second = table.currentSnapshot().snapshotId(); + + long maxRefAge = 10; + for (String timeUnit : TIME_UNITS) { + sql( + "ALTER TABLE %s REPLACE Tag %s AS OF VERSION %d RETAIN %d %s", + tableName, tagName, second, maxRefAge, timeUnit); + + table.refresh(); + SnapshotRef ref = table.refs().get(tagName); + assertThat(ref.snapshotId()) + .as("The tag needs to point to a specific snapshot id.") + .isEqualTo(second); + assertThat(ref.maxRefAgeMs().longValue()) + .as("The tag needs to have the correct max ref age.") + .isEqualTo(TimeUnit.valueOf(timeUnit).toMillis(maxRefAge)); + } + } + + @TestTemplate + public void testCreateOrReplace() throws NoSuchTableException { + Table table = insertRows(); + long first = table.currentSnapshot().snapshotId(); + String tagName = "t1"; + insertRows(); + long second = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createTag(tagName, second).commit(); + + sql("ALTER TABLE %s CREATE OR REPLACE TAG %s AS OF VERSION %d", tableName, tagName, first); + table.refresh(); + SnapshotRef ref = table.refs().get(tagName); + assertThat(ref.snapshotId()) + .as("The tag needs to point to a specific snapshot id.") + .isEqualTo(first); + } + + @TestTemplate + public void testDropTag() throws NoSuchTableException { + insertRows(); + Table table = validationCatalog.loadTable(tableIdent); + String tagName = "t1"; + table.manageSnapshots().createTag(tagName, table.currentSnapshot().snapshotId()).commit(); + SnapshotRef ref = table.refs().get(tagName); + assertThat(ref.snapshotId()).as("").isEqualTo(table.currentSnapshot().snapshotId()); + + sql("ALTER TABLE %s DROP TAG %s", tableName, 
tagName); + table.refresh(); + ref = table.refs().get(tagName); + assertThat(ref).as("The tag needs to be dropped.").isNull(); + } + + @TestTemplate + public void testDropTagNonConformingName() { + assertThatThrownBy(() -> sql("ALTER TABLE %s DROP TAG %s", tableName, "123")) + .isInstanceOf(IcebergParseException.class) + .hasMessageContaining("mismatched input '123'"); + } + + @TestTemplate + public void testDropTagDoesNotExist() { + assertThatThrownBy(() -> sql("ALTER TABLE %s DROP TAG %s", tableName, "nonExistingTag")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Tag does not exist: nonExistingTag"); + } + + @TestTemplate + public void testDropTagFailsForBranch() throws NoSuchTableException { + String branchName = "b1"; + Table table = insertRows(); + table.manageSnapshots().createBranch(branchName, table.currentSnapshot().snapshotId()).commit(); + + assertThatThrownBy(() -> sql("ALTER TABLE %s DROP TAG %s", tableName, branchName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Ref b1 is a branch not a tag"); + } + + @TestTemplate + public void testDropTagIfExists() throws NoSuchTableException { + String tagName = "nonExistingTag"; + Table table = insertRows(); + assertThat(table.refs().get(tagName)).as("The tag does not exist.").isNull(); + + sql("ALTER TABLE %s DROP TAG IF EXISTS %s", tableName, tagName); + table.refresh(); + assertThat(table.refs().get(tagName)).as("The tag still does not exist.").isNull(); + + table.manageSnapshots().createTag(tagName, table.currentSnapshot().snapshotId()).commit(); + assertThat(table.refs().get(tagName).snapshotId()) + .as("The tag has been created successfully.") + .isEqualTo(table.currentSnapshot().snapshotId()); + + sql("ALTER TABLE %s DROP TAG IF EXISTS %s", tableName, tagName); + table.refresh(); + assertThat(table.refs().get(tagName)).as("The tag needs to be dropped.").isNull(); + } + + @TestTemplate + public void createOrReplaceWithNonExistingTag() throws NoSuchTableException { + Table table = insertRows(); + String tagName = "t1"; + insertRows(); + long snapshotId = table.currentSnapshot().snapshotId(); + + sql("ALTER TABLE %s CREATE OR REPLACE TAG %s AS OF VERSION %d", tableName, tagName, snapshotId); + table.refresh(); + assertThat(table.refs().get(tagName).snapshotId()).isEqualTo(snapshotId); + } + + private Table insertRows() throws NoSuchTableException { + List<SimpleRecord> records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + return validationCatalog.loadTable(tableIdent); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java new file mode 100644 index 000000000000..550bf41ce220 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java @@ -0,0 +1,1508 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.DataOperations.OVERWRITE; +import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE; +import static org.apache.iceberg.SnapshotSummary.ADDED_FILES_PROP; +import static org.apache.iceberg.SnapshotSummary.CHANGED_PARTITION_COUNT_PROP; +import static org.apache.iceberg.SnapshotSummary.DELETED_FILES_PROP; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.iceberg.TableProperties.SPLIT_SIZE; +import static org.apache.iceberg.TableProperties.UPDATE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.UPDATE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.UPDATE_MODE; +import static org.apache.iceberg.TableProperties.UPDATE_MODE_DEFAULT; +import static org.apache.spark.sql.functions.lit; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.SparkException; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.execution.SparkPlan; +import org.apache.spark.sql.internal.SQLConf; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterEach; +import 
org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public abstract class TestUpdate extends SparkRowLevelOperationsTestBase { + + @BeforeAll + public static void setupSparkConf() { + spark.conf().set("spark.sql.shuffle.partitions", "4"); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS updated_id"); + sql("DROP TABLE IF EXISTS updated_dep"); + sql("DROP TABLE IF EXISTS deleted_employee"); + } + + @TestTemplate + public void testUpdateWithVectorizedReads() { + assumeThat(supportsVectorization()).isTrue(); + + createAndInitTable( + "id INT, value INT, dep STRING", + "PARTITIONED BY (dep)", + "{ \"id\": 1, \"value\": 100, \"dep\": \"hr\" }"); + + SparkPlan plan = executeAndKeepPlan("UPDATE %s SET value = -1 WHERE id = 1", commitTarget()); + + assertAllBatchScansVectorized(plan); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, -1, "hr")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testCoalesceUpdate() { + createAndInitTable("id INT, dep STRING"); + + String[] records = new String[100]; + for (int index = 0; index < 100; index++) { + records[index] = String.format("{ \"id\": %d, \"dep\": \"hr\" }", index); + } + append(tableName, records); + append(tableName, records); + append(tableName, records); + append(tableName, records); + + // set the open file cost large enough to produce a separate scan task per file + // use range distribution to trigger a shuffle + Map tableProps = + ImmutableMap.of( + SPLIT_OPEN_FILE_COST, + String.valueOf(Integer.MAX_VALUE), + UPDATE_DISTRIBUTION_MODE, + DistributionMode.RANGE.modeName()); + sql("ALTER TABLE %s SET TBLPROPERTIES (%s)", tableName, tablePropsAsString(tableProps)); + + createBranchIfNeeded(); + + // enable AQE and set the advisory partition size big enough to trigger combining + // set the number of shuffle partitions to 200 to distribute the work across reducers + // set the advisory partition size for shuffles small enough to ensure writes override it + withSQLConf( + ImmutableMap.of( + SQLConf.SHUFFLE_PARTITIONS().key(), + "200", + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), + "true", + SQLConf.COALESCE_PARTITIONS_ENABLED().key(), + "true", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES().key(), + "100", + SparkSQLProperties.ADVISORY_PARTITION_SIZE, + String.valueOf(256 * 1024 * 1024)), + () -> { + SparkPlan plan = + executeAndKeepPlan("UPDATE %s SET id = -1 WHERE mod(id, 2) = 0", commitTarget()); + assertThat(plan.toString()).contains("REBALANCE_PARTITIONS_BY_COL"); + }); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, branch); + + if (mode(table) == COPY_ON_WRITE) { + // CoW UPDATE requests the updated records to be range distributed by `_file`, `_pos` + // every task has data for each of 200 reducers + // AQE detects that all shuffle blocks are small and processes them in 1 task + // otherwise, there would be 200 tasks writing to the table + validateProperty(snapshot, SnapshotSummary.ADDED_FILES_PROP, "1"); + } else { + // MoR UPDATE requests the deleted records to be range distributed by partition and `_file` + // each task contains only 1 file and therefore writes only 1 shuffle block + // that means 4 shuffle blocks are distributed among 200 reducers + // AQE detects that all 4 shuffle blocks are small and 
processes them in 1 task + // otherwise, there would be 4 tasks processing 1 shuffle block each + validateProperty(snapshot, SnapshotSummary.ADDED_DELETE_FILES_PROP, "1"); + } + + assertThat(scalarSql("SELECT COUNT(*) FROM %s WHERE id = -1", commitTarget())) + .as("Row count must match") + .isEqualTo(200L); + } + + @TestTemplate + public void testSkewUpdate() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + String[] records = new String[100]; + for (int index = 0; index < 100; index++) { + records[index] = String.format("{ \"id\": %d, \"dep\": \"hr\" }", index); + } + append(tableName, records); + append(tableName, records); + append(tableName, records); + append(tableName, records); + + // set the open file cost large enough to produce a separate scan task per file + // use hash distribution to trigger a shuffle + Map tableProps = + ImmutableMap.of( + SPLIT_OPEN_FILE_COST, + String.valueOf(Integer.MAX_VALUE), + UPDATE_DISTRIBUTION_MODE, + DistributionMode.HASH.modeName()); + sql("ALTER TABLE %s SET TBLPROPERTIES (%s)", tableName, tablePropsAsString(tableProps)); + + createBranchIfNeeded(); + + // enable AQE and set the advisory partition size small enough to trigger a split + // set the number of shuffle partitions to 2 to only have 2 reducers + // set the advisory partition size for shuffles big enough to ensure writes override it + withSQLConf( + ImmutableMap.of( + SQLConf.SHUFFLE_PARTITIONS().key(), + "2", + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), + "true", + SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED().key(), + "true", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES().key(), + "256MB", + SparkSQLProperties.ADVISORY_PARTITION_SIZE, + "100"), + () -> { + SparkPlan plan = + executeAndKeepPlan("UPDATE %s SET id = -1 WHERE mod(id, 2) = 0", commitTarget()); + assertThat(plan.toString()).contains("REBALANCE_PARTITIONS_BY_COL"); + }); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, branch); + + if (mode(table) == COPY_ON_WRITE) { + // CoW UPDATE requests the updated records to be clustered by `_file` + // each task contains only 1 file and therefore writes only 1 shuffle block + // that means 4 shuffle blocks are distributed among 2 reducers + // AQE detects that all shuffle blocks are big and processes them in 4 independent tasks + // otherwise, there would be 2 tasks processing 2 shuffle blocks each + validateProperty(snapshot, SnapshotSummary.ADDED_FILES_PROP, "4"); + } else { + // MoR UPDATE requests the deleted records to be clustered by `_spec_id` and `_partition` + // all tasks belong to the same partition and therefore write only 1 shuffle block per task + // that means there are 4 shuffle blocks, all assigned to the same reducer + // AQE detects that all 4 shuffle blocks are big and processes them in 4 separate tasks + // otherwise, there would be 1 task processing 4 shuffle blocks + validateProperty(snapshot, SnapshotSummary.ADDED_DELETE_FILES_PROP, "4"); + } + + assertThat(scalarSql("SELECT COUNT(*) FROM %s WHERE id = -1", commitTarget())) + .as("Row count must match") + .isEqualTo(200L); + } + + @TestTemplate + public void testExplain() { + createAndInitTable("id INT, dep STRING"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); + + sql("EXPLAIN UPDATE %s SET dep = 'invalid' WHERE id <=> 1", commitTarget()); + + sql("EXPLAIN UPDATE %s SET dep = 'invalid' WHERE 
true", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 1 snapshot").hasSize(1); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testUpdateEmptyTable() { + assumeThat(branch).as("Custom branch does not exist for empty table").isNotEqualTo("test"); + createAndInitTable("id INT, dep STRING"); + + sql("UPDATE %s SET dep = 'invalid' WHERE id IN (1)", commitTarget()); + sql("UPDATE %s SET id = -1 WHERE dep = 'hr'", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 2 snapshots").hasSize(2); + + assertEquals( + "Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testUpdateNonExistingCustomBranch() { + assumeThat(branch).as("Test only applicable to custom branch").isEqualTo("test"); + createAndInitTable("id INT, dep STRING"); + + assertThatThrownBy(() -> sql("UPDATE %s SET dep = 'invalid' WHERE id IN (1)", commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot use branch (does not exist): test"); + } + + @TestTemplate + public void testUpdateWithAlias() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"a\" }"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("UPDATE %s AS t SET t.dep = 'invalid'", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 2 snapshots").hasSize(2); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "invalid")), + sql("SELECT * FROM %s", selectTarget())); + } + + @TestTemplate + public void testUpdateAlignsAssignments() { + createAndInitTable("id INT, c1 INT, c2 INT"); + + sql("INSERT INTO TABLE %s VALUES (1, 11, 111), (2, 22, 222)", tableName); + createBranchIfNeeded(); + + sql("UPDATE %s SET `c2` = c2 - 2, c1 = `c1` - 1 WHERE id <=> 1", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, 10, 109), row(2, 22, 222)), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testUpdateWithUnsupportedPartitionPredicate() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'software'), (2, 'hr')", tableName); + createBranchIfNeeded(); + + sql("UPDATE %s t SET `t`.`id` = -1 WHERE t.dep LIKE '%%r' ", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "software")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testUpdateWithDynamicFileFiltering() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + append( + commitTarget(), + "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + sql("UPDATE %s SET id = cast('-1' AS INT) WHERE id = 2", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 3 snapshots").hasSize(3); + + Snapshot currentSnapshot = 
SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + } else { + validateMergeOnRead(currentSnapshot, "1", "1", "1"); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", commitTarget())); + } + + @TestTemplate + public void testUpdateNonExistingRecords() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); + + sql("UPDATE %s SET id = -1 WHERE id > 10", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 2 snapshots").hasSize(2); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "0", null, null); + } else { + validateMergeOnRead(currentSnapshot, "0", null, null); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testUpdateWithoutCondition() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", commitTarget()); + sql("INSERT INTO TABLE %s VALUES (null, 'hr')", commitTarget()); + + // set the num of shuffle partitions to 200 instead of default 4 to reduce the chance of hashing + // records for multiple source files to one writing task (needed for a predictable num of output + // files) + withSQLConf( + ImmutableMap.of(SQLConf.SHUFFLE_PARTITIONS().key(), "200"), + () -> { + sql("UPDATE %s SET id = -1", commitTarget()); + }); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 4 snapshots").hasSize(4); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + + assertThat(currentSnapshot.operation()).as("Operation must match").isEqualTo(OVERWRITE); + if (mode(table) == COPY_ON_WRITE) { + assertThat(currentSnapshot.operation()).as("Operation must match").isEqualTo(OVERWRITE); + validateProperty(currentSnapshot, CHANGED_PARTITION_COUNT_PROP, "2"); + validateProperty(currentSnapshot, DELETED_FILES_PROP, "3"); + validateProperty(currentSnapshot, ADDED_FILES_PROP, ImmutableSet.of("2", "3")); + } else { + validateMergeOnRead(currentSnapshot, "2", "2", "2"); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(-1, "hr")), + sql("SELECT * FROM %s ORDER BY dep ASC", selectTarget())); + } + + @TestTemplate + public void testUpdateWithNullConditions() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 0, \"dep\": null }\n" + + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }"); + createBranchIfNeeded(); + + // should not update any rows as null is never equal to null + sql("UPDATE %s SET id = -1 WHERE dep = NULL", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(0, null), row(1, "hr"), row(2, 
"hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // should not update any rows the condition does not match any records + sql("UPDATE %s SET id = -1 WHERE dep = 'software'", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // should update one matching row with a null-safe condition + sql("UPDATE %s SET dep = 'invalid', id = -1 WHERE dep <=> NULL", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "invalid"), row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testUpdateWithInAndNotInConditions() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + sql("UPDATE %s SET id = -1 WHERE id IN (1, null)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql("UPDATE %s SET id = 100 WHERE id NOT IN (null, 1)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql("UPDATE %s SET id = 100 WHERE id NOT IN (1, 10)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(100, "hardware"), row(100, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); + } + + @TestTemplate + public void testUpdateWithMultipleRowGroupsParquet() throws NoSuchTableException { + assumeThat(fileFormat).isEqualTo(FileFormat.PARQUET); + + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 100); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, SPLIT_SIZE, 100); + + List ids = Lists.newArrayListWithCapacity(200); + for (int id = 1; id <= 200; id++) { + ids.add(id); + } + Dataset df = + spark + .createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")); + df.coalesce(1).writeTo(tableName).append(); + createBranchIfNeeded(); + + assertThat(spark.table(commitTarget()).count()).isEqualTo(200); + + // update a record from one of two row groups and copy over the second one + sql("UPDATE %s SET id = -1 WHERE id IN (200, 201)", commitTarget()); + + assertThat(spark.table(commitTarget()).count()).isEqualTo(200); + } + + @TestTemplate + public void testUpdateNestedStructFields() { + createAndInitTable( + "id INT, s STRUCT,m:MAP>>", + "{ \"id\": 1, \"s\": { \"c1\": 2, \"c2\": { \"a\": [1,2], \"m\": { \"a\": \"b\"} } } } }"); + + // update primitive, array, map columns inside a struct + sql("UPDATE %s SET s.c1 = -1, s.c2.m = map('k', 'v'), s.c2.a = array(-1)", commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(-1, row(ImmutableList.of(-1), ImmutableMap.of("k", "v"))))), + sql("SELECT * FROM %s", selectTarget())); + + // set primitive, array, map columns to NULL (proper casts should be in place) + sql("UPDATE %s SET s.c1 = NULL, s.c2 = NULL WHERE id IN 
(1)", commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(null, null))), + sql("SELECT * FROM %s", selectTarget())); + + // update all fields in a struct + sql( + "UPDATE %s SET s = named_struct('c1', 1, 'c2', named_struct('a', array(1), 'm', null))", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(1, row(ImmutableList.of(1), null)))), + sql("SELECT * FROM %s", selectTarget())); + } + + @TestTemplate + public void testUpdateWithUserDefinedDistribution() { + createAndInitTable("id INT, c2 INT, c3 INT"); + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(8, c3)", tableName); + + append( + tableName, + "{ \"id\": 1, \"c2\": 11, \"c3\": 1 }\n" + + "{ \"id\": 2, \"c2\": 22, \"c3\": 1 }\n" + + "{ \"id\": 3, \"c2\": 33, \"c3\": 1 }"); + createBranchIfNeeded(); + + // request a global sort + sql("ALTER TABLE %s WRITE ORDERED BY c2", tableName); + sql("UPDATE %s SET c2 = -22 WHERE id NOT IN (1, 3)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, 11, 1), row(2, -22, 1), row(3, 33, 1)), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + // request a local sort + sql("ALTER TABLE %s WRITE LOCALLY ORDERED BY id", tableName); + sql("UPDATE %s SET c2 = -33 WHERE id = 3", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, 11, 1), row(2, -22, 1), row(3, -33, 1)), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + // request a hash distribution + local sort + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY id", tableName); + sql("UPDATE %s SET c2 = -11 WHERE id = 1", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, -11, 1), row(2, -22, 1), row(3, -33, 1)), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public synchronized void testUpdateWithSerializableIsolation() throws InterruptedException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + assumeThat(catalogName).isNotEqualToIgnoringCase("testhadoop"); + // if caching is off, the table is eagerly refreshed during runtime filtering + // this can cause a validation exception as concurrent changes would be visible + assumeThat(cachingCatalogEnabled()).isTrue(); + + createAndInitTable("id INT, dep STRING"); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, UPDATE_ISOLATION_LEVEL, "serializable"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // update thread + Future updateFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + int currentNumOperations = numOperations; + Awaitility.await() + .pollInterval(10, TimeUnit.MILLISECONDS) + .atMost(5, TimeUnit.SECONDS) + .until(() -> barrier.get() >= currentNumOperations * 2); + + sql("UPDATE %s SET id = -1 WHERE id = 1", commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + // load the table via the validation catalog to use another table instance + Table table = validationCatalog.loadTable(tableIdent); + + 
GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + int currentNumOperations = numOperations; + Awaitility.await() + .pollInterval(10, TimeUnit.MILLISECONDS) + .atMost(5, TimeUnit.SECONDS) + .until(() -> !shouldAppend.get() || barrier.get() >= currentNumOperations * 2); + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + } + + barrier.incrementAndGet(); + } + }); + + try { + assertThatThrownBy(updateFuture::get) + .isInstanceOf(ExecutionException.class) + .cause() + .isInstanceOf(ValidationException.class) + .hasMessageContaining("Found conflicting files that can contain"); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + assertThat(executorService.awaitTermination(2, TimeUnit.MINUTES)).as("Timeout").isTrue(); + } + + @TestTemplate + public synchronized void testUpdateWithSnapshotIsolation() + throws InterruptedException, ExecutionException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + assumeThat(catalogName).isNotEqualToIgnoringCase("testhadoop"); + // if caching is off, the table is eagerly refreshed during runtime filtering + // this can cause a validation exception as concurrent changes would be visible + assumeThat(cachingCatalogEnabled()).isTrue(); + + createAndInitTable("id INT, dep STRING"); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, UPDATE_ISOLATION_LEVEL, "snapshot"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // update thread + Future updateFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < 20; numOperations++) { + int currentNumOperations = numOperations; + Awaitility.await() + .pollInterval(10, TimeUnit.MILLISECONDS) + .atMost(5, TimeUnit.SECONDS) + .until(() -> barrier.get() >= currentNumOperations * 2); + + sql("UPDATE %s SET id = -1 WHERE id = 1", tableName); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + // load the table via the validation catalog to use another table instance for inserts + Table table = validationCatalog.loadTable(tableIdent); + + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < 20; numOperations++) { + int currentNumOperations = numOperations; + Awaitility.await() + .pollInterval(10, TimeUnit.MILLISECONDS) + .atMost(5, TimeUnit.SECONDS) + .until(() -> !shouldAppend.get() || barrier.get() >= currentNumOperations * 2); + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = 
table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + } + + barrier.incrementAndGet(); + } + }); + + try { + updateFuture.get(); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + assertThat(executorService.awaitTermination(2, TimeUnit.MINUTES)).as("Timeout").isTrue(); + } + + @TestTemplate + public void testUpdateWithInferredCasts() { + createAndInitTable("id INT, s STRING", "{ \"id\": 1, \"s\": \"value\" }"); + + sql("UPDATE %s SET s = -1 WHERE id = 1", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "-1")), + sql("SELECT * FROM %s", selectTarget())); + } + + @TestTemplate + public void testUpdateModifiesNullStruct() { + createAndInitTable("id INT, s STRUCT", "{ \"id\": 1, \"s\": null }"); + + sql("UPDATE %s SET s.n1 = -1 WHERE id = 1", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, row(-1, null))), + sql("SELECT * FROM %s", selectTarget())); + } + + @TestTemplate + public void testUpdateRefreshesRelationCache() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + append( + commitTarget(), + "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + Dataset query = spark.sql("SELECT * FROM " + commitTarget() + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals( + "View should have correct data", + ImmutableList.of(row(1, "hardware"), row(1, "hr")), + sql("SELECT * FROM tmp ORDER BY id, dep")); + + sql("UPDATE %s SET id = -1 WHERE id = 1", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 3 snapshots").hasSize(3); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "2", "2", "2"); + } else { + validateMergeOnRead(currentSnapshot, "2", "2", "2"); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(2, "hardware"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", commitTarget())); + + assertEquals( + "Should refresh the relation cache", + ImmutableList.of(), + sql("SELECT * FROM tmp ORDER BY id, dep")); + + spark.sql("UNCACHE TABLE tmp"); + } + + @TestTemplate + public void testUpdateWithInSubquery() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + createOrReplaceView("updated_id", Arrays.asList(0, 1, null), Encoders.INT()); + createOrReplaceView("updated_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + sql( + "UPDATE %s SET id = -1 WHERE " + + "id IN (SELECT * FROM updated_id) AND " + + "dep IN (SELECT * from updated_dep)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "UPDATE %s SET id = 5 WHERE id IS NULL OR id IN (SELECT value + 1 FROM updated_id)", + commitTarget()); + assertEquals( + "Should have expected rows", 
+ ImmutableList.of(row(-1, "hr"), row(5, "hardware"), row(5, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + + append( + commitTarget(), "{ \"id\": null, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hr\" }"); + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(-1, "hr"), row(2, "hr"), row(5, "hardware"), row(5, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); + + sql( + "UPDATE %s SET id = 10 WHERE id IN (SELECT value + 2 FROM updated_id) AND dep = 'hr'", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(-1, "hr"), row(5, "hardware"), row(5, "hr"), row(10, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); + } + + @TestTemplate + public void testUpdateWithInSubqueryAndDynamicFileFiltering() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION", tableName); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + append( + commitTarget(), + "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + createOrReplaceView("updated_id", Arrays.asList(-1, 2), Encoders.INT()); + + sql("UPDATE %s SET id = -1 WHERE id IN (SELECT * FROM updated_id)", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 3 snapshots").hasSize(3); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + } else { + validateMergeOnRead(currentSnapshot, "1", "1", "1"); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", commitTarget())); + } + + @TestTemplate + public void testUpdateWithSelfSubquery() { + createAndInitTable("id INT, dep STRING"); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + sql( + "UPDATE %s SET dep = 'x' WHERE id IN (SELECT id + 1 FROM %s)", + commitTarget(), commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "x")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // TODO: Spark does not support AQE and DPP with aggregates at the moment + withSQLConf( + ImmutableMap.of(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"), + () -> { + sql( + "UPDATE %s SET dep = 'y' WHERE " + + "id = (SELECT count(*) FROM (SELECT DISTINCT id FROM %s) AS t)", + commitTarget(), commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "y")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + }); + + sql("UPDATE %s SET id = (SELECT id - 2 FROM %s WHERE id = 1)", commitTarget(), commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(-1, "y")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } + + @TestTemplate + public void testUpdateWithMultiColumnInSubquery() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); 
+ + List<Employee> deletedEmployees = + Arrays.asList(new Employee(null, "hr"), new Employee(1, "hr")); + createOrReplaceView("deleted_employee", deletedEmployees, Encoders.bean(Employee.class)); + + sql( + "UPDATE %s SET dep = 'x', id = -1 WHERE (id, dep) IN (SELECT id, dep FROM deleted_employee)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "x"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @TestTemplate + public void testUpdateWithNotInSubquery() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("updated_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + // the file filter subquery (nested loop left-anti join) returns 0 records + sql("UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id WHERE value IS NOT NULL)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); + + sql( + "UPDATE %s SET id = 5 WHERE id NOT IN (SELECT * FROM updated_id) OR dep IN ('software', 'hr')", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(5, "hr"), row(5, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); + } + + @TestTemplate + public void testUpdateWithExistSubquery() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("updated_dep", Arrays.asList("hr", null), Encoders.STRING()); + + sql( + "UPDATE %s t SET id = -1 WHERE EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "UPDATE %s t SET dep = 'x', id = -1 WHERE " + + "EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value + 2)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "x"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "UPDATE %s t SET id = -2 WHERE " + + "EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value) OR " + + "t.id IS NULL", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-2, "hr"), row(-2, "x"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + + sql( + "UPDATE %s t SET id = 1 WHERE " + + "EXISTS (SELECT 1 FROM updated_id ui WHERE t.id = ui.value) AND " + + "EXISTS (SELECT 1 FROM updated_dep ud WHERE t.dep =
ud.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-2, "x"), row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } + + @TestTemplate + public void testUpdateWithNotExistsSubquery() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("updated_dep", Arrays.asList("hr", "software"), Encoders.STRING()); + + sql( + "UPDATE %s t SET id = -1 WHERE NOT EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value + 2)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(1, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + + sql( + "UPDATE %s t SET id = 5 WHERE " + + "NOT EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value) OR " + + "t.id = 1", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(5, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + + sql( + "UPDATE %s t SET id = 10 WHERE " + + "NOT EXISTS (SELECT 1 FROM updated_id ui WHERE t.id = ui.value) AND " + + "EXISTS (SELECT 1 FROM updated_dep ud WHERE t.dep = ud.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(10, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } + + @TestTemplate + public void testUpdateWithScalarSubquery() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + createOrReplaceView("updated_id", Arrays.asList(1, 100, null), Encoders.INT()); + + // TODO: Spark does not support AQE and DPP with aggregates at the moment + withSQLConf( + ImmutableMap.of(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"), + () -> { + sql( + "UPDATE %s SET id = -1 WHERE id <= (SELECT min(value) FROM updated_id)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + }); + } + + @TestTemplate + public void testUpdateThatRequiresGroupingBeforeWrite() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + append( + tableName, + "{ \"id\": 0, \"dep\": \"hr\" }\n" + + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + append( + commitTarget(), + "{ \"id\": 0, \"dep\": \"ops\" }\n" + + "{ \"id\": 1, \"dep\": \"ops\" }\n" + + "{ \"id\": 2, \"dep\": \"ops\" }"); + + append( + commitTarget(), + "{ \"id\": 0, \"dep\": \"hr\" }\n" + + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }"); + + append( + commitTarget(), + "{ \"id\": 0, \"dep\": \"ops\" }\n" + + "{ \"id\": 1, \"dep\": \"ops\" }\n" + + "{ \"id\": 2, \"dep\": \"ops\" }"); + + createOrReplaceView("updated_id", Arrays.asList(1, 100), Encoders.INT()); + + String originalNumOfShufflePartitions = spark.conf().get("spark.sql.shuffle.partitions"); + try { + // set the num of shuffle partitions to 1 to 
ensure we have only 1 writing task + spark.conf().set("spark.sql.shuffle.partitions", "1"); + + sql("UPDATE %s t SET id = -1 WHERE id IN (SELECT * FROM updated_id)", commitTarget()); + assertThat(spark.table(commitTarget()).count()) + .as("Should have expected num of rows") + .isEqualTo(12L); + } finally { + spark.conf().set("spark.sql.shuffle.partitions", originalNumOfShufflePartitions); + } + } + + @TestTemplate + public void testUpdateWithVectorization() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 0, \"dep\": \"hr\" }\n" + + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.VECTORIZATION_ENABLED, "true"), + () -> { + sql("UPDATE %s t SET id = -1", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(-1, "hr"), row(-1, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + }); + } + + @TestTemplate + public void testUpdateModifyPartitionSourceField() throws NoSuchTableException { + createAndInitTable("id INT, dep STRING, country STRING"); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(4, id)", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + List ids = Lists.newArrayListWithCapacity(100); + for (int id = 1; id <= 100; id++) { + ids.add(id); + } + + Dataset df1 = + spark + .createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")) + .withColumn("country", lit("usa")); + df1.coalesce(1).writeTo(tableName).append(); + createBranchIfNeeded(); + + Dataset df2 = + spark + .createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("software")) + .withColumn("country", lit("usa")); + df2.coalesce(1).writeTo(commitTarget()).append(); + + Dataset df3 = + spark + .createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hardware")) + .withColumn("country", lit("usa")); + df3.coalesce(1).writeTo(commitTarget()).append(); + + sql( + "UPDATE %s SET id = -1 WHERE id IN (10, 11, 12, 13, 14, 15, 16, 17, 18, 19)", + commitTarget()); + assertThat(scalarSql("SELECT count(*) FROM %s WHERE id = -1", selectTarget())).isEqualTo(30L); + } + + @TestTemplate + public void testUpdateWithStaticPredicatePushdown() { + createAndInitTable("id INT, dep STRING"); + + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + // add a data file to the 'software' partition + append(tableName, "{ \"id\": 1, \"dep\": \"software\" }"); + createBranchIfNeeded(); + + // add a data file to the 'hr' partition + append(commitTarget(), "{ \"id\": 1, \"dep\": \"hr\" }"); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, branch); + String dataFilesCount = snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP); + assertThat(dataFilesCount).as("Must have 2 files before UPDATE").isEqualTo("2"); + + // remove the data file from the 'hr' partition to ensure it is not scanned + DataFile dataFile = Iterables.getOnlyElement(snapshot.addedDataFiles(table.io())); + table.io().deleteFile(dataFile.location()); + + // disable dynamic pruning and rely only on static predicate pushdown + withSQLConf( + ImmutableMap.of( + SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false", + SQLConf.RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED().key(), "false"), + () -> { + sql("UPDATE %s SET id = -1 WHERE dep IN 
('software') AND id == 1", commitTarget()); + }); + } + + @TestTemplate + public void testUpdateWithInvalidUpdates() { + createAndInitTable( + "id INT, a ARRAY<STRUCT<c1:INT,c2:INT>>, m MAP<STRING,STRING>", + "{ \"id\": 0, \"a\": null, \"m\": null }"); + + assertThatThrownBy(() -> sql("UPDATE %s SET a.c1 = 1", commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Updating nested fields is only supported for StructType"); + + assertThatThrownBy(() -> sql("UPDATE %s SET m.key = 'new_key'", commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Updating nested fields is only supported for StructType"); + } + + @TestTemplate + public void testUpdateWithConflictingAssignments() { + createAndInitTable( + "id INT, c STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>>", "{ \"id\": 0, \"s\": null }"); + + assertThatThrownBy(() -> sql("UPDATE %s t SET t.id = 1, t.c.n1 = 2, t.id = 2", commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Multiple assignments for 'id'"); + + assertThatThrownBy( + () -> sql("UPDATE %s t SET t.c.n1 = 1, t.id = 2, t.c.n1 = 2", commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Multiple assignments for 'c.n1"); + + assertThatThrownBy( + () -> + sql( + "UPDATE %s SET c.n1 = 1, c = named_struct('n1', 1, 'n2', named_struct('dn1', 1, 'dn2', 2))", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Conflicting assignments for 'c'"); + } + + @TestTemplate + public void testUpdateWithInvalidAssignmentsAnsi() { + createAndInitTable( + "id INT NOT NULL, s STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>> NOT NULL", + "{ \"id\": 0, \"s\": { \"n1\": 1, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + + withSQLConf( + ImmutableMap.of("spark.sql.storeAssignmentPolicy", "ansi"), + () -> { + assertThatThrownBy(() -> sql("UPDATE %s t SET t.id = NULL", commitTarget())) + .isInstanceOf(SparkException.class) + .hasMessageContaining("Null value appeared in non-nullable field"); + + assertThatThrownBy(() -> sql("UPDATE %s t SET t.s.n1 = NULL", commitTarget())) + .isInstanceOf(SparkException.class) + .hasMessageContaining("Null value appeared in non-nullable field"); + + assertThatThrownBy( + () -> sql("UPDATE %s t SET t.s = named_struct('n1', 1)", commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot find data for the output column `s`.`n2`"); + + assertThatThrownBy(() -> sql("UPDATE %s t SET t.s.n1 = 'str'", commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot safely cast"); + + assertThatThrownBy( + () -> + sql( + "UPDATE %s t SET t.s.n2 = named_struct('dn3', 1, 'dn1', 2)", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot find data for the output column `s`.`n2`.`dn2`"); + }); + } + + @TestTemplate + public void testUpdateWithInvalidAssignmentsStrict() { + createAndInitTable( + "id INT NOT NULL, s STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>> NOT NULL", + "{ \"id\": 0, \"s\": { \"n1\": 1, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + + withSQLConf( + ImmutableMap.of("spark.sql.storeAssignmentPolicy", "strict"), + () -> { + assertThatThrownBy(() -> sql("UPDATE %s t SET t.id = NULL", commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot safely cast `id` \"VOID\" to \"INT\""); + + assertThatThrownBy(() -> sql("UPDATE %s t SET t.s.n1 = NULL", commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot safely cast `s`.`n1` \"VOID\" to \"INT\""); + + assertThatThrownBy( + () -> sql("UPDATE %s t SET t.s = named_struct('n1', 1)",
commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot find data for the output column"); + + assertThatThrownBy(() -> sql("UPDATE %s t SET t.s.n1 = 'str'", commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot safely cast"); + + assertThatThrownBy( + () -> + sql( + "UPDATE %s t SET t.s.n2 = named_struct('dn3', 1, 'dn1', 2)", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot find data for the output column"); + }); + } + + @TestTemplate + public void testUpdateWithNonDeterministicCondition() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"hr\" }"); + + assertThatThrownBy( + () -> sql("UPDATE %s SET id = -1 WHERE id = 1 AND rand() > 0.5", commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("The operator expects a deterministic expression"); + } + + @TestTemplate + public void testUpdateOnNonIcebergTableNotSupported() { + createOrReplaceView("testtable", "{ \"c1\": -100, \"c2\": -200 }"); + + assertThatThrownBy(() -> sql("UPDATE %s SET c1 = -1 WHERE c2 = 1", "testtable")) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessage("UPDATE TABLE is not supported temporarily."); + } + + @TestTemplate + public void testUpdateToWAPBranch() { + assumeThat(branch).as("WAP branch only works for table identifier without branch").isNull(); + + createAndInitTable( + "id INT, dep STRING", "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"a\" }"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql("UPDATE %s SET dep='hr' WHERE dep='a'", tableName); + assertThat(sql("SELECT * FROM %s WHERE dep='hr'", tableName)) + .as("Should have expected num of rows when reading table") + .hasSize(2); + assertThat(sql("SELECT * FROM %s.branch_wap WHERE dep='hr'", tableName)) + .as("Should have expected num of rows when reading WAP branch") + .hasSize(2); + assertThat(sql("SELECT * FROM %s.branch_main WHERE dep='hr'", tableName)) + .as("Should not modify main branch") + .hasSize(1); + }); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql("UPDATE %s SET dep='b' WHERE dep='hr'", tableName); + assertThat(sql("SELECT * FROM %s WHERE dep='b'", tableName)) + .as("Should have expected num of rows when reading table with multiple writes") + .hasSize(2); + assertThat(sql("SELECT * FROM %s.branch_wap WHERE dep='b'", tableName)) + .as("Should have expected num of rows when reading WAP branch with multiple writes") + .hasSize(2); + assertThat(sql("SELECT * FROM %s.branch_main WHERE dep='b'", tableName)) + .as("Should not modify main branch with multiple writes") + .hasSize(0); + }); + } + + @TestTemplate + public void testUpdateToWapBranchWithTableBranchIdentifier() { + assumeThat(branch).as("Test must have branch name part in table identifier").isNotNull(); + + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"hr\" }"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> + assertThatThrownBy(() -> sql("UPDATE %s SET dep='hr' WHERE dep='a'", commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage( + String.format( + "Cannot write to both branch and WAP branch, but got branch 
[%s] and WAP branch [wap]", + branch))); + } + + private RowLevelOperationMode mode(Table table) { + String modeName = table.properties().getOrDefault(UPDATE_MODE, UPDATE_MODE_DEFAULT); + return RowLevelOperationMode.fromName(modeName); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestViews.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestViews.java new file mode 100644 index 000000000000..819656a95c78 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestViews.java @@ -0,0 +1,1975 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import java.util.Locale; +import java.util.Random; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.iceberg.IcebergBuild; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Schema; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.catalog.ViewCatalog; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.view.ImmutableSQLViewRepresentation; +import org.apache.iceberg.view.SQLViewRepresentation; +import org.apache.iceberg.view.View; +import org.apache.iceberg.view.ViewHistoryEntry; +import org.apache.iceberg.view.ViewProperties; +import org.apache.iceberg.view.ViewVersion; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.catalog.SessionCatalog; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestViews extends ExtensionsTestBase { + private static final Namespace NAMESPACE = Namespace.of("default"); + private final String tableName = "table"; + + @BeforeEach + public void 
before() { + super.before(); + spark.conf().set("spark.sql.defaultCatalog", catalogName); + sql("USE %s", catalogName); + sql("CREATE NAMESPACE IF NOT EXISTS %s", NAMESPACE); + sql("CREATE TABLE %s (id INT, data STRING)", tableName); + } + + @AfterEach + public void removeTable() { + sql("USE %s", catalogName); + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.SPARK_WITH_VIEWS.catalogName(), + SparkCatalogConfig.SPARK_WITH_VIEWS.implementation(), + SparkCatalogConfig.SPARK_WITH_VIEWS.properties() + } + }; + } + + @TestTemplate + public void readFromView() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("simpleView"); + String sql = String.format("SELECT id FROM %s", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + // use non-existing column name to make sure only the SQL definition for spark is loaded + .withQuery("trino", String.format("SELECT non_existing FROM %s", tableName)) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + List expected = + IntStream.rangeClosed(1, 10).mapToObj(this::row).collect(Collectors.toList()); + + assertThat(sql("SELECT * FROM %s", viewName)) + .hasSize(10) + .containsExactlyInAnyOrderElementsOf(expected); + } + + @TestTemplate + public void readFromTrinoView() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("trinoView"); + String sql = String.format("SELECT id FROM %s", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("trino", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + List expected = + IntStream.rangeClosed(1, 10).mapToObj(this::row).collect(Collectors.toList()); + + // there's no explicit view defined for spark, so it will fall back to the defined trino view + assertThat(sql("SELECT * FROM %s", viewName)) + .hasSize(10) + .containsExactlyInAnyOrderElementsOf(expected); + } + + @TestTemplate + public void readFromMultipleViews() throws NoSuchTableException { + insertRows(6); + String viewName = viewName("firstView"); + String secondView = viewName("secondView"); + String viewSQL = String.format("SELECT id FROM %s WHERE id <= 3", tableName); + String secondViewSQL = String.format("SELECT id FROM %s WHERE id > 3", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", viewSQL) + .withDefaultNamespace(NAMESPACE) + .withSchema(schema(viewSQL)) + .create(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, secondView)) + .withQuery("spark", secondViewSQL) + .withDefaultNamespace(NAMESPACE) + .withSchema(schema(secondViewSQL)) + .create(); + + assertThat(sql("SELECT * FROM %s", viewName)) + .hasSize(3) + .containsExactlyInAnyOrder(row(1), row(2), row(3)); + + assertThat(sql("SELECT * FROM %s", secondView)) + .hasSize(3) + .containsExactlyInAnyOrder(row(4), row(5), row(6)); + } + + @TestTemplate + public void readFromViewUsingNonExistingTable() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("viewWithNonExistingTable"); + + ViewCatalog viewCatalog = viewCatalog(); + Schema 
schema = new Schema(Types.NestedField.required(1, "id", Types.LongType.get())); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", "SELECT id FROM non_existing") + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema) + .create(); + + assertThatThrownBy(() -> sql("SELECT * FROM %s", viewName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format( + "The table or view `%s`.`%s`.`non_existing` cannot be found", + catalogName, NAMESPACE)); + } + + @TestTemplate + public void readFromViewUsingNonExistingTableColumn() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("viewWithNonExistingColumn"); + + ViewCatalog viewCatalog = viewCatalog(); + Schema schema = new Schema(Types.NestedField.required(1, "non_existing", Types.LongType.get())); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", String.format("SELECT non_existing FROM %s", tableName)) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema) + .create(); + + assertThatThrownBy(() -> sql("SELECT * FROM %s", viewName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "A column, variable, or function parameter with name `non_existing` cannot be resolved"); + } + + @TestTemplate + public void readFromViewUsingInvalidSQL() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("viewWithInvalidSQL"); + + ViewCatalog viewCatalog = viewCatalog(); + Schema schema = tableCatalog().loadTable(TableIdentifier.of(NAMESPACE, tableName)).schema(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", "invalid SQL") + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema) + .create(); + + assertThatThrownBy(() -> sql("SELECT * FROM %s", viewName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining(String.format("Invalid view name: %s", viewName)); + } + + @TestTemplate + public void readFromViewWithStaleSchema() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("staleView"); + String sql = String.format("SELECT id, data FROM %s", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + // drop a column the view depends on + // note that this tests `data` because it has an invalid ordinal + sql("ALTER TABLE %s DROP COLUMN data", tableName); + + // reading from the view should now fail + assertThatThrownBy(() -> sql("SELECT * FROM %s", viewName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "A column, variable, or function parameter with name `data` cannot be resolved"); + } + + @TestTemplate + public void readFromViewHiddenByTempView() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("viewHiddenByTempView"); + + ViewCatalog viewCatalog = viewCatalog(); + Schema schema = tableCatalog().loadTable(TableIdentifier.of(NAMESPACE, tableName)).schema(); + + sql("CREATE TEMPORARY VIEW %s AS SELECT id FROM %s WHERE id <= 5", viewName, tableName); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", String.format("SELECT id FROM %s WHERE id > 5", tableName)) + .withDefaultNamespace(NAMESPACE) + 
.withDefaultCatalog(catalogName) + .withSchema(schema) + .create(); + + List expected = + IntStream.rangeClosed(1, 5).mapToObj(this::row).collect(Collectors.toList()); + + // returns the results from the TEMP VIEW + assertThat(sql("SELECT * FROM %s", viewName)) + .hasSize(5) + .containsExactlyInAnyOrderElementsOf(expected); + } + + @TestTemplate + public void readFromViewWithGlobalTempView() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("viewWithGlobalTempView"); + String sql = String.format("SELECT id FROM %s WHERE id > 5", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + sql("CREATE GLOBAL TEMPORARY VIEW %s AS SELECT id FROM %s WHERE id <= 5", viewName, tableName); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + // GLOBAL TEMP VIEWS are stored in a global_temp namespace + assertThat(sql("SELECT * FROM global_temp.%s", viewName)) + .hasSize(5) + .containsExactlyInAnyOrderElementsOf( + IntStream.rangeClosed(1, 5).mapToObj(this::row).collect(Collectors.toList())); + + assertThat(sql("SELECT * FROM %s", viewName)) + .hasSize(5) + .containsExactlyInAnyOrderElementsOf( + IntStream.rangeClosed(6, 10).mapToObj(this::row).collect(Collectors.toList())); + } + + @TestTemplate + public void readFromViewReferencingAnotherView() throws NoSuchTableException { + insertRows(10); + String firstView = viewName("viewBeingReferencedInAnotherView"); + String viewReferencingOtherView = viewName("viewReferencingOtherView"); + String firstSQL = String.format("SELECT id FROM %s WHERE id <= 5", tableName); + String secondSQL = String.format("SELECT id FROM %s WHERE id > 4", firstView); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, firstView)) + .withQuery("spark", firstSQL) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(firstSQL)) + .create(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewReferencingOtherView)) + .withQuery("spark", secondSQL) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(secondSQL)) + .create(); + + assertThat(sql("SELECT * FROM %s", viewReferencingOtherView)) + .hasSize(1) + .containsExactly(row(5)); + } + + @TestTemplate + public void readFromViewReferencingTempView() throws NoSuchTableException { + insertRows(10); + String tempView = viewName("tempViewBeingReferencedInAnotherView"); + String viewReferencingTempView = viewName("viewReferencingTempView"); + String sql = String.format("SELECT id FROM %s", tempView); + + ViewCatalog viewCatalog = viewCatalog(); + + sql("CREATE TEMPORARY VIEW %s AS SELECT id FROM %s WHERE id <= 5", tempView, tableName); + + // it wouldn't be possible to reference a TEMP VIEW if the view had been created via SQL, + // but this can't be prevented when using the API directly + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewReferencingTempView)) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + List expected = + IntStream.rangeClosed(1, 5).mapToObj(this::row).collect(Collectors.toList()); + + assertThat(sql("SELECT * FROM %s", tempView)) + .hasSize(5) + .containsExactlyInAnyOrderElementsOf(expected); + + // reading from a view that references a TEMP VIEW shouldn't be possible + assertThatThrownBy(() -> 
sql("SELECT * FROM %s", viewReferencingTempView)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("The table or view") + .hasMessageContaining(tempView) + .hasMessageContaining("cannot be found"); + } + + @TestTemplate + public void readFromViewReferencingAnotherViewHiddenByTempView() throws NoSuchTableException { + insertRows(10); + String innerViewName = viewName("inner_view"); + String outerViewName = viewName("outer_view"); + String innerViewSQL = String.format("SELECT * FROM %s WHERE id > 5", tableName); + String outerViewSQL = String.format("SELECT id FROM %s", innerViewName); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, innerViewName)) + .withQuery("spark", innerViewSQL) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(innerViewSQL)) + .create(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, outerViewName)) + .withQuery("spark", outerViewSQL) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(outerViewSQL)) + .create(); + + // create a temporary view that conflicts with the inner view to verify the inner name is + // resolved using the catalog and namespace defaults from the outer view + sql("CREATE TEMPORARY VIEW %s AS SELECT id FROM %s WHERE id <= 5", innerViewName, tableName); + + // ensure that the inner view resolution uses the view namespace and catalog + sql("USE spark_catalog"); + + List tempViewRows = + IntStream.rangeClosed(1, 5).mapToObj(this::row).collect(Collectors.toList()); + + assertThat(sql("SELECT * FROM %s", innerViewName)) + .hasSize(5) + .containsExactlyInAnyOrderElementsOf(tempViewRows); + + List expectedViewRows = + IntStream.rangeClosed(6, 10).mapToObj(this::row).collect(Collectors.toList()); + + assertThat(sql("SELECT * FROM %s.%s.%s", catalogName, NAMESPACE, outerViewName)) + .hasSize(5) + .containsExactlyInAnyOrderElementsOf(expectedViewRows); + } + + @TestTemplate + public void readFromViewReferencingGlobalTempView() throws NoSuchTableException { + insertRows(10); + String globalTempView = viewName("globalTempViewBeingReferenced"); + String viewReferencingTempView = viewName("viewReferencingGlobalTempView"); + + ViewCatalog viewCatalog = viewCatalog(); + Schema schema = tableCatalog().loadTable(TableIdentifier.of(NAMESPACE, tableName)).schema(); + + sql( + "CREATE GLOBAL TEMPORARY VIEW %s AS SELECT id FROM %s WHERE id <= 5", + globalTempView, tableName); + + // it wouldn't be possible to reference a GLOBAL TEMP VIEW if the view had been created via SQL, + // but this can't be prevented when using the API directly + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewReferencingTempView)) + .withQuery("spark", String.format("SELECT id FROM global_temp.%s", globalTempView)) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema) + .create(); + + List expected = + IntStream.rangeClosed(1, 5).mapToObj(this::row).collect(Collectors.toList()); + + assertThat(sql("SELECT * FROM global_temp.%s", globalTempView)) + .hasSize(5) + .containsExactlyInAnyOrderElementsOf(expected); + + // reading from a view that references a GLOBAL TEMP VIEW shouldn't be possible + assertThatThrownBy(() -> sql("SELECT * FROM %s", viewReferencingTempView)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("The table or view") + .hasMessageContaining(globalTempView) + .hasMessageContaining("cannot be found"); + } + + @TestTemplate + public void 
readFromViewReferencingTempFunction() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("viewReferencingTempFunction"); + String functionName = viewName("test_avg"); + String sql = String.format("SELECT %s(id) FROM %s", functionName, tableName); + sql( + "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'", + functionName); + + ViewCatalog viewCatalog = viewCatalog(); + Schema schema = tableCatalog().loadTable(TableIdentifier.of(NAMESPACE, tableName)).schema(); + + // it wouldn't be possible to reference a TEMP FUNCTION if the view had been created via SQL, + // but this can't be prevented when using the API directly + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema) + .create(); + + assertThat(sql(sql)).hasSize(1).containsExactly(row(5.5)); + + // reading from a view that references a TEMP FUNCTION shouldn't be possible + assertThatThrownBy(() -> sql("SELECT * FROM %s", viewName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format("The routine %s.%s cannot be found", NAMESPACE, functionName)); + } + + @TestTemplate + public void readFromViewWithCTE() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("viewWithCTE"); + String sql = + String.format( + "WITH max_by_data AS (SELECT max(id) as max FROM %s) " + + "SELECT max, count(1) AS count FROM max_by_data GROUP BY max", + tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + assertThat(sql("SELECT * FROM %s", viewName)).hasSize(1).containsExactly(row(10, 1L)); + } + + @TestTemplate + public void rewriteFunctionIdentifier() { + String viewName = viewName("rewriteFunctionIdentifier"); + String sql = "SELECT iceberg_version() AS version"; + + assertThatThrownBy(() -> sql(sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot resolve routine") + .hasMessageContaining("iceberg_version"); + + ViewCatalog viewCatalog = viewCatalog(); + Schema schema = new Schema(Types.NestedField.required(1, "version", Types.StringType.get())); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(Namespace.of("system")) + .withDefaultCatalog(catalogName) + .withSchema(schema) + .create(); + + assertThat(sql("SELECT * FROM %s", viewName)) + .hasSize(1) + .containsExactly(row(IcebergBuild.version())); + } + + @TestTemplate + public void builtinFunctionIdentifierNotRewritten() { + String viewName = viewName("builtinFunctionIdentifierNotRewritten"); + String sql = "SELECT trim(' abc ') AS result"; + + ViewCatalog viewCatalog = viewCatalog(); + Schema schema = new Schema(Types.NestedField.required(1, "result", Types.StringType.get())); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(Namespace.of("system")) + .withDefaultCatalog(catalogName) + .withSchema(schema) + .create(); + + assertThat(sql("SELECT * FROM %s", viewName)).hasSize(1).containsExactly(row("abc")); + } + + @TestTemplate + public void rewriteFunctionIdentifierWithNamespace() { + String viewName = viewName("rewriteFunctionIdentifierWithNamespace"); + String sql = "SELECT 
system.bucket(100, 'a') AS bucket_result, 'a' AS value"; + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(Namespace.of("system")) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + sql("USE spark_catalog"); + + assertThatThrownBy(() -> sql(sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot resolve routine") + .hasMessageContaining("`system`.`bucket`"); + + assertThat(sql("SELECT * FROM %s.%s.%s", catalogName, NAMESPACE, viewName)) + .hasSize(1) + .containsExactly(row(50, "a")); + } + + @TestTemplate + public void fullFunctionIdentifier() { + String viewName = viewName("fullFunctionIdentifier"); + String sql = + String.format( + "SELECT %s.system.bucket(100, 'a') AS bucket_result, 'a' AS value", catalogName); + + sql("USE spark_catalog"); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(Namespace.of("system")) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + assertThat(sql("SELECT * FROM %s.%s.%s", catalogName, NAMESPACE, viewName)) + .hasSize(1) + .containsExactly(row(50, "a")); + } + + @TestTemplate + public void fullFunctionIdentifierNotRewrittenLoadFailure() { + String viewName = viewName("fullFunctionIdentifierNotRewrittenLoadFailure"); + String sql = "SELECT spark_catalog.system.bucket(100, 'a') AS bucket_result, 'a' AS value"; + + // avoid namespace failures + sql("USE spark_catalog"); + sql("CREATE NAMESPACE IF NOT EXISTS system"); + sql("USE %s", catalogName); + + Schema schema = + new Schema( + Types.NestedField.required(1, "bucket_result", Types.IntegerType.get()), + Types.NestedField.required(2, "value", Types.StringType.get())); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(Namespace.of("system")) + .withDefaultCatalog(catalogName) + .withSchema(schema) + .create(); + + // verify the v1 error message + assertThatThrownBy(() -> sql("SELECT * FROM %s", viewName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("The routine `system`.`bucket` cannot be found"); + } + + private Schema schema(String sql) { + return SparkSchemaUtil.convert(spark.sql(sql).schema()); + } + + private ViewCatalog viewCatalog() { + Catalog icebergCatalog = Spark3Util.loadIcebergCatalog(spark, catalogName); + assertThat(icebergCatalog).isInstanceOf(ViewCatalog.class); + return (ViewCatalog) icebergCatalog; + } + + private Catalog tableCatalog() { + return Spark3Util.loadIcebergCatalog(spark, catalogName); + } + + @TestTemplate + public void renameView() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("originalView"); + String renamedView = viewName("renamedView"); + String sql = String.format("SELECT id FROM %s", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + sql("ALTER VIEW %s RENAME TO %s", viewName, renamedView); + + List<Object[]> expected = + IntStream.rangeClosed(1, 10).mapToObj(this::row).collect(Collectors.toList()); + assertThat(sql("SELECT * FROM %s", renamedView)) + .hasSize(10) + 
.containsExactlyInAnyOrderElementsOf(expected); + } + + @TestTemplate + public void renameViewHiddenByTempView() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("originalView"); + String renamedView = viewName("renamedView"); + String sql = String.format("SELECT id FROM %s WHERE id > 5", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + sql("CREATE TEMPORARY VIEW %s AS SELECT id FROM %s WHERE id <= 5", viewName, tableName); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + // renames the TEMP VIEW + sql("ALTER VIEW %s RENAME TO %s", viewName, renamedView); + assertThat(sql("SELECT * FROM %s", renamedView)) + .hasSize(5) + .containsExactlyInAnyOrderElementsOf( + IntStream.rangeClosed(1, 5).mapToObj(this::row).collect(Collectors.toList())); + + // original view still exists with its name + assertThat(viewCatalog.viewExists(TableIdentifier.of(NAMESPACE, viewName))).isTrue(); + assertThat(viewCatalog.viewExists(TableIdentifier.of(NAMESPACE, renamedView))).isFalse(); + assertThat(sql("SELECT * FROM %s", viewName)) + .hasSize(5) + .containsExactlyInAnyOrderElementsOf( + IntStream.rangeClosed(6, 10).mapToObj(this::row).collect(Collectors.toList())); + + // will rename the Iceberg view + sql("ALTER VIEW %s RENAME TO %s", viewName, renamedView); + assertThat(viewCatalog.viewExists(TableIdentifier.of(NAMESPACE, renamedView))).isTrue(); + } + + @TestTemplate + public void renameViewToDifferentTargetCatalog() { + String viewName = viewName("originalView"); + String renamedView = viewName("renamedView"); + String sql = String.format("SELECT id FROM %s", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + assertThatThrownBy(() -> sql("ALTER VIEW %s RENAME TO spark_catalog.%s", viewName, renamedView)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "Cannot move view between catalogs: from=spark_with_views and to=spark_catalog"); + } + + @TestTemplate + public void renameNonExistingView() { + assertThatThrownBy(() -> sql("ALTER VIEW non_existing RENAME TO target")) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("The table or view `non_existing` cannot be found"); + } + + @TestTemplate + public void renameViewTargetAlreadyExistsAsView() { + String viewName = viewName("renameViewSource"); + String target = viewName("renameViewTarget"); + String sql = String.format("SELECT id FROM %s", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, target)) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + assertThatThrownBy(() -> sql("ALTER VIEW %s RENAME TO %s", viewName, target)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format("Cannot create view default.%s because it already exists", target)); + } + + @TestTemplate + public void renameViewTargetAlreadyExistsAsTable() { + String 
viewName = viewName("renameViewSource"); + String target = viewName("renameViewTarget"); + String sql = String.format("SELECT id FROM %s", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + sql("CREATE TABLE %s (id INT, data STRING)", target); + assertThatThrownBy(() -> sql("ALTER VIEW %s RENAME TO %s", viewName, target)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format("Cannot create view default.%s because it already exists", target)); + } + + @TestTemplate + public void dropView() { + String viewName = viewName("viewToBeDropped"); + String sql = String.format("SELECT id FROM %s", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + TableIdentifier identifier = TableIdentifier.of(NAMESPACE, viewName); + viewCatalog + .buildView(identifier) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + assertThat(viewCatalog.viewExists(identifier)).isTrue(); + + sql("DROP VIEW %s", viewName); + assertThat(viewCatalog.viewExists(identifier)).isFalse(); + } + + @TestTemplate + public void dropNonExistingView() { + assertThatThrownBy(() -> sql("DROP VIEW non_existing")) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("The view %s.%s cannot be found", NAMESPACE, "non_existing"); + } + + @TestTemplate + public void dropViewIfExists() { + String viewName = viewName("viewToBeDropped"); + String sql = String.format("SELECT id FROM %s", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + TableIdentifier identifier = TableIdentifier.of(NAMESPACE, viewName); + viewCatalog + .buildView(identifier) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + assertThat(viewCatalog.viewExists(identifier)).isTrue(); + + sql("DROP VIEW IF EXISTS %s", viewName); + assertThat(viewCatalog.viewExists(identifier)).isFalse(); + + assertThatNoException().isThrownBy(() -> sql("DROP VIEW IF EXISTS %s", viewName)); + } + + /** The purpose of this test is mainly to make sure that normal view deletion isn't messed up */ + @TestTemplate + public void dropGlobalTempView() { + String globalTempView = viewName("globalViewToBeDropped"); + sql("CREATE GLOBAL TEMPORARY VIEW %s AS SELECT id FROM %s", globalTempView, tableName); + assertThat(v1SessionCatalog().getGlobalTempView(globalTempView).isDefined()).isTrue(); + + sql("DROP VIEW global_temp.%s", globalTempView); + assertThat(v1SessionCatalog().getGlobalTempView(globalTempView).isDefined()).isFalse(); + } + + /** The purpose of this test is mainly to make sure that normal view deletion isn't messed up */ + @TestTemplate + public void dropTempView() { + String tempView = viewName("tempViewToBeDropped"); + sql("CREATE TEMPORARY VIEW %s AS SELECT id FROM %s", tempView, tableName); + assertThat(v1SessionCatalog().getTempView(tempView).isDefined()).isTrue(); + + sql("DROP VIEW %s", tempView); + assertThat(v1SessionCatalog().getTempView(tempView).isDefined()).isFalse(); + } + + /** The purpose of this test is mainly to make sure that normal view deletion isn't messed up */ + @TestTemplate + public void dropV1View() { + String v1View = viewName("v1ViewToBeDropped"); + sql("USE spark_catalog"); + sql("CREATE NAMESPACE IF NOT EXISTS %s", 
NAMESPACE); + sql("CREATE TABLE %s (id INT, data STRING)", tableName); + sql("CREATE VIEW %s AS SELECT id FROM %s", v1View, tableName); + sql("USE %s", catalogName); + assertThat( + v1SessionCatalog() + .tableExists(new org.apache.spark.sql.catalyst.TableIdentifier(v1View))) + .isTrue(); + + sql("DROP VIEW spark_catalog.%s.%s", NAMESPACE, v1View); + assertThat( + v1SessionCatalog() + .tableExists(new org.apache.spark.sql.catalyst.TableIdentifier(v1View))) + .isFalse(); + + sql("USE spark_catalog"); + sql("DROP TABLE IF EXISTS %s", tableName); + } + + private SessionCatalog v1SessionCatalog() { + return spark.sessionState().catalogManager().v1SessionCatalog(); + } + + private String viewName(String viewName) { + return viewName + new Random().nextInt(1000000); + } + + @TestTemplate + public void createViewIfNotExists() { + String viewName = viewName("viewThatAlreadyExists"); + sql("CREATE VIEW %s AS SELECT id FROM %s", viewName, tableName); + + assertThatThrownBy(() -> sql("CREATE VIEW %s AS SELECT id FROM %s", viewName, tableName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format( + "Cannot create view %s.%s because it already exists", NAMESPACE, viewName)); + + // using IF NOT EXISTS should work + assertThatNoException() + .isThrownBy( + () -> sql("CREATE VIEW IF NOT EXISTS %s AS SELECT id FROM %s", viewName, tableName)); + } + + @TestTemplate + public void createOrReplaceView() throws NoSuchTableException { + insertRows(6); + String viewName = viewName("simpleView"); + + sql("CREATE OR REPLACE VIEW %s AS SELECT id FROM %s WHERE id <= 3", viewName, tableName); + assertThat(sql("SELECT id FROM %s", viewName)) + .hasSize(3) + .containsExactlyInAnyOrder(row(1), row(2), row(3)); + + sql("CREATE OR REPLACE VIEW %s AS SELECT id FROM %s WHERE id > 3", viewName, tableName); + assertThat(sql("SELECT id FROM %s", viewName)) + .hasSize(3) + .containsExactlyInAnyOrder(row(4), row(5), row(6)); + } + + @TestTemplate + public void createViewWithInvalidSQL() { + assertThatThrownBy(() -> sql("CREATE VIEW simpleViewWithInvalidSQL AS invalid SQL")) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Syntax error"); + } + + @TestTemplate + public void createViewReferencingTempView() throws NoSuchTableException { + insertRows(10); + String tempView = viewName("temporaryViewBeingReferencedInAnotherView"); + String viewReferencingTempView = viewName("viewReferencingTemporaryView"); + + sql("CREATE TEMPORARY VIEW %s AS SELECT id FROM %s WHERE id <= 5", tempView, tableName); + + // creating a view that references a TEMP VIEW shouldn't be possible + assertThatThrownBy( + () -> sql("CREATE VIEW %s AS SELECT id FROM %s", viewReferencingTempView, tempView)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format( + "Cannot create view %s.%s.%s", catalogName, NAMESPACE, viewReferencingTempView)) + .hasMessageContaining("that references temporary view:") + .hasMessageContaining(tempView); + } + + @TestTemplate + public void createViewReferencingGlobalTempView() throws NoSuchTableException { + insertRows(10); + String globalTempView = viewName("globalTemporaryViewBeingReferenced"); + String viewReferencingTempView = viewName("viewReferencingGlobalTemporaryView"); + + sql( + "CREATE GLOBAL TEMPORARY VIEW %s AS SELECT id FROM %s WHERE id <= 5", + globalTempView, tableName); + + // creating a view that references a GLOBAL TEMP VIEW shouldn't be possible + assertThatThrownBy( + () -> + sql( + "CREATE VIEW %s AS SELECT id FROM global_temp.%s", + 
viewReferencingTempView, globalTempView)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format( + "Cannot create view %s.%s.%s", catalogName, NAMESPACE, viewReferencingTempView)) + .hasMessageContaining("that references temporary view:") + .hasMessageContaining(String.format("%s.%s", "global_temp", globalTempView)); + } + + @TestTemplate + public void createViewReferencingTempFunction() { + String viewName = viewName("viewReferencingTemporaryFunction"); + String functionName = viewName("test_avg_func"); + + sql( + "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'", + functionName); + + // creating a view that references a TEMP FUNCTION shouldn't be possible + assertThatThrownBy( + () -> sql("CREATE VIEW %s AS SELECT %s(id) FROM %s", viewName, functionName, tableName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format("Cannot create view %s.%s.%s", catalogName, NAMESPACE, viewName)) + .hasMessageContaining("that references temporary function:") + .hasMessageContaining(functionName); + } + + @TestTemplate + public void createViewReferencingQualifiedTempFunction() { + String viewName = viewName("viewReferencingTemporaryFunction"); + String functionName = viewName("test_avg_func_qualified"); + + sql( + "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'", + functionName); + + // TEMP Function can't be referenced using catalog.schema.name + assertThatThrownBy( + () -> + sql( + "CREATE VIEW %s AS SELECT %s.%s.%s(id) FROM %s", + viewName, catalogName, NAMESPACE, functionName, tableName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot resolve routine") + .hasMessageContaining( + String.format("`%s`.`%s`.`%s`", catalogName, NAMESPACE, functionName)); + + // TEMP Function can't be referenced using schema.name + assertThatThrownBy( + () -> + sql( + "CREATE VIEW %s AS SELECT %s.%s(id) FROM %s", + viewName, NAMESPACE, functionName, tableName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot resolve routine") + .hasMessageContaining(String.format("`%s`.`%s`", NAMESPACE, functionName)); + } + + @TestTemplate + public void createViewUsingNonExistingTable() { + assertThatThrownBy( + () -> sql("CREATE VIEW viewWithNonExistingTable AS SELECT id FROM non_existing")) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("The table or view `non_existing` cannot be found"); + } + + @TestTemplate + public void createViewWithMismatchedColumnCounts() { + String viewName = viewName("viewWithMismatchedColumnCounts"); + + assertThatThrownBy( + () -> sql("CREATE VIEW %s (id, data) AS SELECT id FROM %s", viewName, tableName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format("Cannot create view %s.%s.%s", catalogName, NAMESPACE, viewName)) + .hasMessageContaining("not enough data columns") + .hasMessageContaining("View columns: id, data") + .hasMessageContaining("Data columns: id"); + + assertThatThrownBy( + () -> sql("CREATE VIEW %s (id) AS SELECT id, data FROM %s", viewName, tableName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format("Cannot create view %s.%s.%s", catalogName, NAMESPACE, viewName)) + .hasMessageContaining("too many data columns") + .hasMessageContaining("View columns: id") + .hasMessageContaining("Data columns: id, data"); + } + + @TestTemplate + public void createViewWithColumnAliases() throws NoSuchTableException { + insertRows(6); + 
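// create a view with aliased, commented columns and check both the stored Iceberg schema and the query results +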
String viewName = viewName("viewWithColumnAliases"); + + sql( + "CREATE VIEW %s (new_id COMMENT 'ID', new_data COMMENT 'DATA') AS SELECT id, data FROM %s WHERE id <= 3", + viewName, tableName); + + View view = viewCatalog().loadView(TableIdentifier.of(NAMESPACE, viewName)); + assertThat(view.properties()).containsEntry("spark.query-column-names", "id,data"); + + assertThat(view.schema().columns()).hasSize(2); + Types.NestedField first = view.schema().columns().get(0); + assertThat(first.name()).isEqualTo("new_id"); + assertThat(first.doc()).isEqualTo("ID"); + + Types.NestedField second = view.schema().columns().get(1); + assertThat(second.name()).isEqualTo("new_data"); + assertThat(second.doc()).isEqualTo("DATA"); + + assertThat(sql("SELECT new_id FROM %s", viewName)) + .hasSize(3) + .containsExactlyInAnyOrder(row(1), row(2), row(3)); + + sql("DROP VIEW %s", viewName); + + sql( + "CREATE VIEW %s (new_data, new_id) AS SELECT data, id FROM %s WHERE id <= 3", + viewName, tableName); + + assertThat(sql("SELECT new_id FROM %s", viewName)) + .hasSize(3) + .containsExactlyInAnyOrder(row(1), row(2), row(3)); + } + + @TestTemplate + public void createViewWithDuplicateColumnNames() { + assertThatThrownBy( + () -> + sql( + "CREATE VIEW viewWithDuplicateColumnNames (new_id, new_id) AS SELECT id, id FROM %s WHERE id <= 3", + tableName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("The column `new_id` already exists"); + } + + @TestTemplate + public void createViewWithDuplicateQueryColumnNames() throws NoSuchTableException { + insertRows(3); + String viewName = viewName("viewWithDuplicateQueryColumnNames"); + String sql = String.format("SELECT id, id FROM %s WHERE id <= 3", tableName); + + // not specifying column aliases in the view should fail + assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("The column `id` already exists"); + + sql("CREATE VIEW %s (id_one, id_two) AS %s", viewName, sql); + + assertThat(sql("SELECT * FROM %s", viewName)) + .hasSize(3) + .containsExactlyInAnyOrder(row(1, 1), row(2, 2), row(3, 3)); + } + + @TestTemplate + public void createViewWithCTE() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("simpleViewWithCTE"); + String sql = + String.format( + "WITH max_by_data AS (SELECT max(id) as max FROM %s) " + + "SELECT max, count(1) AS count FROM max_by_data GROUP BY max", + tableName); + + sql("CREATE VIEW %s AS %s", viewName, sql); + + assertThat(sql("SELECT * FROM %s", viewName)).hasSize(1).containsExactly(row(10, 1L)); + } + + @TestTemplate + public void createViewWithConflictingNamesForCTEAndTempView() throws NoSuchTableException { + insertRows(10); + String viewName = viewName("viewWithConflictingNamesForCTEAndTempView"); + String cteName = viewName("cteName"); + String sql = + String.format( + "WITH %s AS (SELECT max(id) as max FROM %s) " + + "(SELECT max, count(1) AS count FROM %s GROUP BY max)", + cteName, tableName, cteName); + + // create a CTE and a TEMP VIEW with the same name + sql("CREATE TEMPORARY VIEW %s AS SELECT * from %s", cteName, tableName); + sql("CREATE VIEW %s AS %s", viewName, sql); + + // CTE should take precedence over the TEMP VIEW when data is read + assertThat(sql("SELECT * FROM %s", viewName)).hasSize(1).containsExactly(row(10, 1L)); + } + + @TestTemplate + public void createViewWithCTEReferencingTempView() { + String viewName = viewName("viewWithCTEReferencingTempView"); + String tempViewInCTE = 
viewName("tempViewInCTE"); + String sql = + String.format( + "WITH max_by_data AS (SELECT max(id) as max FROM %s) " + + "SELECT max, count(1) AS count FROM max_by_data GROUP BY max", + tempViewInCTE); + + sql("CREATE TEMPORARY VIEW %s AS SELECT id FROM %s WHERE ID <= 5", tempViewInCTE, tableName); + + assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format("Cannot create view %s.%s.%s", catalogName, NAMESPACE, viewName)) + .hasMessageContaining("that references temporary view:") + .hasMessageContaining(tempViewInCTE); + } + + @TestTemplate + public void createViewWithCTEReferencingTempFunction() { + String viewName = viewName("viewWithCTEReferencingTempFunction"); + String functionName = viewName("avg_function_in_cte"); + String sql = + String.format( + "WITH avg_data AS (SELECT %s(id) as avg FROM %s) " + + "SELECT avg, count(1) AS count FROM avg_data GROUP BY max", + functionName, tableName); + + sql( + "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'", + functionName); + + assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format("Cannot create view %s.%s.%s", catalogName, NAMESPACE, viewName)) + .hasMessageContaining("that references temporary function:") + .hasMessageContaining(functionName); + } + + @TestTemplate + public void createViewWithNonExistingQueryColumn() { + assertThatThrownBy( + () -> + sql( + "CREATE VIEW viewWithNonExistingQueryColumn AS SELECT non_existing FROM %s WHERE id <= 3", + tableName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "A column, variable, or function parameter with name `non_existing` cannot be resolved"); + } + + @TestTemplate + public void createViewWithSubqueryExpressionUsingTempView() { + String viewName = viewName("viewWithSubqueryExpression"); + String tempView = viewName("simpleTempView"); + String sql = + String.format("SELECT * FROM %s WHERE id = (SELECT id FROM %s)", tableName, tempView); + + sql("CREATE TEMPORARY VIEW %s AS SELECT id from %s WHERE id = 5", tempView, tableName); + + assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format("Cannot create view %s.%s.%s", catalogName, NAMESPACE, viewName)) + .hasMessageContaining("that references temporary view:") + .hasMessageContaining(tempView); + } + + @TestTemplate + public void createViewWithSubqueryExpressionUsingGlobalTempView() { + String viewName = viewName("simpleViewWithSubqueryExpression"); + String globalTempView = viewName("simpleGlobalTempView"); + String sql = + String.format( + "SELECT * FROM %s WHERE id = (SELECT id FROM global_temp.%s)", + tableName, globalTempView); + + sql( + "CREATE GLOBAL TEMPORARY VIEW %s AS SELECT id from %s WHERE id = 5", + globalTempView, tableName); + + assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format("Cannot create view %s.%s.%s", catalogName, NAMESPACE, viewName)) + .hasMessageContaining("that references temporary view:") + .hasMessageContaining(String.format("%s.%s", "global_temp", globalTempView)); + } + + @TestTemplate + public void createViewWithSubqueryExpressionUsingTempFunction() { + String viewName = viewName("viewWithSubqueryExpression"); + String functionName = 
viewName("avg_function_in_subquery"); + String sql = + String.format( + "SELECT * FROM %s WHERE id < (SELECT %s(id) FROM %s)", + tableName, functionName, tableName); + + sql( + "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'", + functionName); + + assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + String.format("Cannot create view %s.%s.%s", catalogName, NAMESPACE, viewName)) + .hasMessageContaining("that references temporary function:") + .hasMessageContaining(functionName); + } + + @TestTemplate + public void createViewWithSubqueryExpressionInFilterThatIsRewritten() + throws NoSuchTableException { + insertRows(5); + String viewName = viewName("viewWithSubqueryExpression"); + String sql = + String.format( + "SELECT id FROM %s WHERE id = (SELECT max(id) FROM %s)", tableName, tableName); + + sql("CREATE VIEW %s AS %s", viewName, sql); + + assertThat(sql("SELECT * FROM %s", viewName)).hasSize(1).containsExactly(row(5)); + + sql("USE spark_catalog"); + + assertThatThrownBy(() -> sql(sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining(String.format("The table or view `%s` cannot be found", tableName)); + + // the underlying SQL in the View should be rewritten to have catalog & namespace + assertThat(sql("SELECT * FROM %s.%s.%s", catalogName, NAMESPACE, viewName)) + .hasSize(1) + .containsExactly(row(5)); + } + + @TestTemplate + public void createViewWithSubqueryExpressionInQueryThatIsRewritten() throws NoSuchTableException { + insertRows(3); + String viewName = viewName("viewWithSubqueryExpression"); + String sql = + String.format("SELECT (SELECT max(id) FROM %s) max_id FROM %s", tableName, tableName); + + sql("CREATE VIEW %s AS %s", viewName, sql); + + assertThat(sql("SELECT * FROM %s", viewName)) + .hasSize(3) + .containsExactly(row(3), row(3), row(3)); + + sql("USE spark_catalog"); + + assertThatThrownBy(() -> sql(sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining(String.format("The table or view `%s` cannot be found", tableName)); + + // the underlying SQL in the View should be rewritten to have catalog & namespace + assertThat(sql("SELECT * FROM %s.%s.%s", catalogName, NAMESPACE, viewName)) + .hasSize(3) + .containsExactly(row(3), row(3), row(3)); + } + + @TestTemplate + public void describeView() { + String viewName = viewName("describeView"); + + sql("CREATE VIEW %s AS SELECT id, data FROM %s WHERE id <= 3", viewName, tableName); + assertThat(sql("DESCRIBE %s", viewName)) + .containsExactly(row("id", "int", ""), row("data", "string", "")); + } + + @TestTemplate + public void describeExtendedView() { + String viewName = viewName("describeExtendedView"); + String sql = String.format("SELECT id, data FROM %s WHERE id <= 3", tableName); + + sql( + "CREATE VIEW %s (new_id COMMENT 'ID', new_data COMMENT 'DATA') COMMENT 'view comment' AS %s", + viewName, sql); + assertThat(sql("DESCRIBE EXTENDED %s", viewName)) + .contains( + row("new_id", "int", "ID"), + row("new_data", "string", "DATA"), + row("", "", ""), + row("# Detailed View Information", "", ""), + row("Comment", "view comment", ""), + row("View Catalog and Namespace", String.format("%s.%s", catalogName, NAMESPACE), ""), + row("View Query Output Columns", "[id, data]", ""), + row( + "View Properties", + String.format( + "['format-version' = '1', 'location' = '/%s/%s', 'provider' = 'iceberg']", + NAMESPACE, viewName), + "")); + } + + @TestTemplate + public void 
showViewProperties() { + String viewName = viewName("showViewProps"); + + sql( + "CREATE VIEW %s TBLPROPERTIES ('key1'='val1', 'key2'='val2') AS SELECT id, data FROM %s WHERE id <= 3", + viewName, tableName); + assertThat(sql("SHOW TBLPROPERTIES %s", viewName)) + .contains(row("key1", "val1"), row("key2", "val2")); + } + + @TestTemplate + public void showViewPropertiesByKey() { + String viewName = viewName("showViewPropsByKey"); + + sql("CREATE VIEW %s AS SELECT id, data FROM %s WHERE id <= 3", viewName, tableName); + assertThat(sql("SHOW TBLPROPERTIES %s", viewName)).contains(row("provider", "iceberg")); + + assertThat(sql("SHOW TBLPROPERTIES %s (provider)", viewName)) + .contains(row("provider", "iceberg")); + + assertThat(sql("SHOW TBLPROPERTIES %s (non.existing)", viewName)) + .contains( + row( + "non.existing", + String.format( + "View %s.%s.%s does not have property: non.existing", + catalogName, NAMESPACE, viewName))); + } + + @TestTemplate + public void showViews() throws NoSuchTableException { + insertRows(6); + String sql = String.format("SELECT * from %s", tableName); + String v1 = viewName("v1"); + String prefixV2 = viewName("prefixV2"); + String prefixV3 = viewName("prefixV3"); + String globalViewForListing = viewName("globalViewForListing"); + String tempViewForListing = viewName("tempViewForListing"); + sql("CREATE VIEW %s AS %s", v1, sql); + sql("CREATE VIEW %s AS %s", prefixV2, sql); + sql("CREATE VIEW %s AS %s", prefixV3, sql); + sql("CREATE GLOBAL TEMPORARY VIEW %s AS %s", globalViewForListing, sql); + sql("CREATE TEMPORARY VIEW %s AS %s", tempViewForListing, sql); + + // spark stores temp views case-insensitive by default + Object[] tempView = row("", tempViewForListing.toLowerCase(Locale.ROOT), true); + assertThat(sql("SHOW VIEWS")) + .contains( + row(NAMESPACE.toString(), prefixV2, false), + row(NAMESPACE.toString(), prefixV3, false), + row(NAMESPACE.toString(), v1, false), + tempView); + + assertThat(sql("SHOW VIEWS IN %s", catalogName)) + .contains( + row(NAMESPACE.toString(), prefixV2, false), + row(NAMESPACE.toString(), prefixV3, false), + row(NAMESPACE.toString(), v1, false), + tempView); + + assertThat(sql("SHOW VIEWS IN %s.%s", catalogName, NAMESPACE)) + .contains( + row(NAMESPACE.toString(), prefixV2, false), + row(NAMESPACE.toString(), prefixV3, false), + row(NAMESPACE.toString(), v1, false), + tempView); + + assertThat(sql("SHOW VIEWS LIKE 'pref*'")) + .contains( + row(NAMESPACE.toString(), prefixV2, false), row(NAMESPACE.toString(), prefixV3, false)); + + assertThat(sql("SHOW VIEWS LIKE 'non-existing'")).isEmpty(); + + assertThat(sql("SHOW VIEWS IN spark_catalog.default")).contains(tempView); + + assertThat(sql("SHOW VIEWS IN global_temp")) + .contains( + // spark stores temp views case-insensitive by default + row("global_temp", globalViewForListing.toLowerCase(Locale.ROOT), true), tempView); + + sql("USE spark_catalog"); + assertThat(sql("SHOW VIEWS")).contains(tempView); + + assertThat(sql("SHOW VIEWS IN default")).contains(tempView); + } + + @TestTemplate + public void showViewsWithCurrentNamespace() { + String namespaceOne = "show_views_ns1"; + String namespaceTwo = "show_views_ns2"; + String viewOne = viewName("viewOne"); + String viewTwo = viewName("viewTwo"); + sql("CREATE NAMESPACE IF NOT EXISTS %s", namespaceOne); + sql("CREATE NAMESPACE IF NOT EXISTS %s", namespaceTwo); + + // create one view in each namespace + sql("CREATE VIEW %s.%s AS SELECT * FROM %s.%s", namespaceOne, viewOne, NAMESPACE, tableName); + sql("CREATE VIEW %s.%s AS SELECT * FROM 
%s.%s", namespaceTwo, viewTwo, NAMESPACE, tableName); + + Object[] v1 = row(namespaceOne, viewOne, false); + Object[] v2 = row(namespaceTwo, viewTwo, false); + + assertThat(sql("SHOW VIEWS IN %s.%s", catalogName, namespaceOne)) + .contains(v1) + .doesNotContain(v2); + sql("USE %s", namespaceOne); + assertThat(sql("SHOW VIEWS")).contains(v1).doesNotContain(v2); + assertThat(sql("SHOW VIEWS LIKE 'viewOne*'")).contains(v1).doesNotContain(v2); + + assertThat(sql("SHOW VIEWS IN %s.%s", catalogName, namespaceTwo)) + .contains(v2) + .doesNotContain(v1); + sql("USE %s", namespaceTwo); + assertThat(sql("SHOW VIEWS")).contains(v2).doesNotContain(v1); + assertThat(sql("SHOW VIEWS LIKE 'viewTwo*'")).contains(v2).doesNotContain(v1); + } + + @TestTemplate + public void showCreateSimpleView() { + String viewName = viewName("showCreateSimpleView"); + String sql = String.format("SELECT id, data FROM %s WHERE id <= 3", tableName); + + sql("CREATE VIEW %s AS %s", viewName, sql); + + String expected = + String.format( + "CREATE VIEW %s.%s.%s (\n" + + " id,\n" + + " data)\n" + + "TBLPROPERTIES (\n" + + " 'format-version' = '1',\n" + + " 'location' = '/%s/%s',\n" + + " 'provider' = 'iceberg')\n" + + "AS\n%s\n", + catalogName, NAMESPACE, viewName, NAMESPACE, viewName, sql); + assertThat(sql("SHOW CREATE TABLE %s", viewName)).containsExactly(row(expected)); + } + + @TestTemplate + public void showCreateComplexView() { + String viewName = viewName("showCreateComplexView"); + String sql = String.format("SELECT id, data FROM %s WHERE id <= 3", tableName); + + sql( + "CREATE VIEW %s (new_id COMMENT 'ID', new_data COMMENT 'DATA')" + + "COMMENT 'view comment' TBLPROPERTIES ('key1'='val1', 'key2'='val2') AS %s", + viewName, sql); + + String expected = + String.format( + "CREATE VIEW %s.%s.%s (\n" + + " new_id COMMENT 'ID',\n" + + " new_data COMMENT 'DATA')\n" + + "COMMENT 'view comment'\n" + + "TBLPROPERTIES (\n" + + " 'format-version' = '1',\n" + + " 'key1' = 'val1',\n" + + " 'key2' = 'val2',\n" + + " 'location' = '/%s/%s',\n" + + " 'provider' = 'iceberg')\n" + + "AS\n%s\n", + catalogName, NAMESPACE, viewName, NAMESPACE, viewName, sql); + assertThat(sql("SHOW CREATE TABLE %s", viewName)).containsExactly(row(expected)); + } + + @TestTemplate + public void alterViewSetProperties() { + String viewName = viewName("viewWithSetProperties"); + + sql("CREATE VIEW %s AS SELECT id FROM %s WHERE id <= 3", viewName, tableName); + + ViewCatalog viewCatalog = viewCatalog(); + assertThat(viewCatalog.loadView(TableIdentifier.of(NAMESPACE, viewName)).properties()) + .doesNotContainKey("key1") + .doesNotContainKey("comment"); + + sql("ALTER VIEW %s SET TBLPROPERTIES ('key1' = 'val1', 'comment' = 'view comment')", viewName); + assertThat(viewCatalog.loadView(TableIdentifier.of(NAMESPACE, viewName)).properties()) + .containsEntry("key1", "val1") + .containsEntry("comment", "view comment"); + + sql("ALTER VIEW %s SET TBLPROPERTIES ('key1' = 'new_val1')", viewName); + assertThat(viewCatalog.loadView(TableIdentifier.of(NAMESPACE, viewName)).properties()) + .containsEntry("key1", "new_val1") + .containsEntry("comment", "view comment"); + } + + @TestTemplate + public void alterViewSetReservedProperties() { + String viewName = viewName("viewWithSetReservedProperties"); + + sql("CREATE VIEW %s AS SELECT id FROM %s WHERE id <= 3", viewName, tableName); + + assertThatThrownBy(() -> sql("ALTER VIEW %s SET TBLPROPERTIES ('provider' = 'val1')", viewName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "The feature is not 
supported: provider is a reserved table property"); + + assertThatThrownBy( + () -> sql("ALTER VIEW %s SET TBLPROPERTIES ('location' = 'random_location')", viewName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "The feature is not supported: location is a reserved table property"); + + assertThatThrownBy( + () -> sql("ALTER VIEW %s SET TBLPROPERTIES ('format-version' = '99')", viewName)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("Cannot set reserved property: 'format-version'"); + + assertThatThrownBy( + () -> + sql( + "ALTER VIEW %s SET TBLPROPERTIES ('spark.query-column-names' = 'a,b,c')", + viewName)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("Cannot set reserved property: 'spark.query-column-names'"); + } + + @TestTemplate + public void alterViewUnsetProperties() { + String viewName = viewName("viewWithUnsetProperties"); + sql("CREATE VIEW %s AS SELECT id FROM %s WHERE id <= 3", viewName, tableName); + + ViewCatalog viewCatalog = viewCatalog(); + assertThat(viewCatalog.loadView(TableIdentifier.of(NAMESPACE, viewName)).properties()) + .doesNotContainKey("key1") + .doesNotContainKey("comment"); + + sql("ALTER VIEW %s SET TBLPROPERTIES ('key1' = 'val1', 'comment' = 'view comment')", viewName); + assertThat(viewCatalog.loadView(TableIdentifier.of(NAMESPACE, viewName)).properties()) + .containsEntry("key1", "val1") + .containsEntry("comment", "view comment"); + + sql("ALTER VIEW %s UNSET TBLPROPERTIES ('key1')", viewName); + assertThat(viewCatalog.loadView(TableIdentifier.of(NAMESPACE, viewName)).properties()) + .doesNotContainKey("key1") + .containsEntry("comment", "view comment"); + } + + @TestTemplate + public void alterViewUnsetUnknownProperty() { + String viewName = viewName("viewWithUnsetUnknownProp"); + sql("CREATE VIEW %s AS SELECT id FROM %s WHERE id <= 3", viewName, tableName); + + assertThatThrownBy(() -> sql("ALTER VIEW %s UNSET TBLPROPERTIES ('unknown-key')", viewName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot remove property that is not set: 'unknown-key'"); + + assertThatNoException() + .isThrownBy( + () -> sql("ALTER VIEW %s UNSET TBLPROPERTIES IF EXISTS ('unknown-key')", viewName)); + } + + @TestTemplate + public void alterViewUnsetReservedProperties() { + String viewName = viewName("viewWithUnsetReservedProperties"); + + sql("CREATE VIEW %s AS SELECT id FROM %s WHERE id <= 3", viewName, tableName); + + assertThatThrownBy(() -> sql("ALTER VIEW %s UNSET TBLPROPERTIES ('provider')", viewName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "The feature is not supported: provider is a reserved table property"); + + assertThatThrownBy(() -> sql("ALTER VIEW %s UNSET TBLPROPERTIES ('location')", viewName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "The feature is not supported: location is a reserved table property"); + + assertThatThrownBy(() -> sql("ALTER VIEW %s UNSET TBLPROPERTIES ('format-version')", viewName)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("Cannot unset reserved property: 'format-version'"); + + // spark.query-column-names is only used internally, so it technically doesn't exist on a Spark + // VIEW + assertThatThrownBy( + () -> sql("ALTER VIEW %s UNSET TBLPROPERTIES ('spark.query-column-names')", viewName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("Cannot remove property that is not set: 'spark.query-column-names'"); + + 
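// even with IF EXISTS, the reserved internal property cannot be unset +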
assertThatThrownBy( + () -> + sql( + "ALTER VIEW %s UNSET TBLPROPERTIES IF EXISTS ('spark.query-column-names')", + viewName)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("Cannot unset reserved property: 'spark.query-column-names'"); + } + + @TestTemplate + public void createOrReplaceViewWithColumnAliases() throws NoSuchTableException { + insertRows(6); + String viewName = viewName("viewWithColumnAliases"); + + sql( + "CREATE VIEW %s (new_id COMMENT 'ID', new_data COMMENT 'DATA') AS SELECT id, data FROM %s WHERE id <= 3", + viewName, tableName); + + View view = viewCatalog().loadView(TableIdentifier.of(NAMESPACE, viewName)); + assertThat(view.properties()).containsEntry("spark.query-column-names", "id,data"); + + assertThat(view.schema().columns()).hasSize(2); + Types.NestedField first = view.schema().columns().get(0); + assertThat(first.name()).isEqualTo("new_id"); + assertThat(first.doc()).isEqualTo("ID"); + + Types.NestedField second = view.schema().columns().get(1); + assertThat(second.name()).isEqualTo("new_data"); + assertThat(second.doc()).isEqualTo("DATA"); + + assertThat(sql("SELECT new_id FROM %s", viewName)) + .hasSize(3) + .containsExactlyInAnyOrder(row(1), row(2), row(3)); + + sql( + "CREATE OR REPLACE VIEW %s (data2 COMMENT 'new data', id2 COMMENT 'new ID') AS SELECT data, id FROM %s WHERE id <= 3", + viewName, tableName); + + assertThat(sql("SELECT data2, id2 FROM %s", viewName)) + .hasSize(3) + .containsExactlyInAnyOrder(row("2", 1), row("4", 2), row("6", 3)); + + view = viewCatalog().loadView(TableIdentifier.of(NAMESPACE, viewName)); + assertThat(view.properties()).containsEntry("spark.query-column-names", "data,id"); + + assertThat(view.schema().columns()).hasSize(2); + first = view.schema().columns().get(0); + assertThat(first.name()).isEqualTo("data2"); + assertThat(first.doc()).isEqualTo("new data"); + + second = view.schema().columns().get(1); + assertThat(second.name()).isEqualTo("id2"); + assertThat(second.doc()).isEqualTo("new ID"); + } + + @TestTemplate + public void alterViewIsNotSupported() throws NoSuchTableException { + insertRows(6); + String viewName = viewName("alteredView"); + + sql("CREATE VIEW %s AS SELECT id, data FROM %s WHERE id <= 3", viewName, tableName); + + assertThat(sql("SELECT id FROM %s", viewName)) + .hasSize(3) + .containsExactlyInAnyOrder(row(1), row(2), row(3)); + + assertThatThrownBy( + () -> sql("ALTER VIEW %s AS SELECT id FROM %s WHERE id > 3", viewName, tableName)) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "ALTER VIEW AS is not supported. 
Use CREATE OR REPLACE VIEW instead"); + } + + @TestTemplate + public void createOrReplaceViewKeepsViewHistory() { + String viewName = viewName("viewWithHistoryAfterReplace"); + String sql = String.format("SELECT id, data FROM %s WHERE id <= 3", tableName); + String updatedSql = String.format("SELECT id FROM %s WHERE id > 3", tableName); + + sql( + "CREATE VIEW %s (new_id COMMENT 'some ID', new_data COMMENT 'some data') AS %s", + viewName, sql); + + View view = viewCatalog().loadView(TableIdentifier.of(NAMESPACE, viewName)); + assertThat(view.history()).hasSize(1); + assertThat(view.sqlFor("spark").sql()).isEqualTo(sql); + assertThat(view.currentVersion().versionId()).isEqualTo(1); + assertThat(view.currentVersion().schemaId()).isEqualTo(0); + assertThat(view.schemas()).hasSize(1); + assertThat(view.schema().asStruct()) + .isEqualTo( + new Schema( + Types.NestedField.optional(0, "new_id", Types.IntegerType.get(), "some ID"), + Types.NestedField.optional(1, "new_data", Types.StringType.get(), "some data")) + .asStruct()); + + sql("CREATE OR REPLACE VIEW %s (updated_id COMMENT 'updated ID') AS %s", viewName, updatedSql); + + view = viewCatalog().loadView(TableIdentifier.of(NAMESPACE, viewName)); + assertThat(view.history()).hasSize(2); + assertThat(view.sqlFor("spark").sql()).isEqualTo(updatedSql); + assertThat(view.currentVersion().versionId()).isEqualTo(2); + assertThat(view.currentVersion().schemaId()).isEqualTo(1); + assertThat(view.schemas()).hasSize(2); + assertThat(view.schema().asStruct()) + .isEqualTo( + new Schema( + Types.NestedField.optional( + 0, "updated_id", Types.IntegerType.get(), "updated ID")) + .asStruct()); + } + + @TestTemplate + public void replacingTrinoViewShouldFail() { + String viewName = viewName("trinoView"); + String sql = String.format("SELECT id FROM %s", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("trino", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS %s", viewName, sql)) + .isInstanceOf(IllegalStateException.class) + .hasMessage( + "Cannot replace view due to loss of view dialects (replace.drop-dialect.allowed=false):\n" + + "Previous dialects: [trino]\n" + + "New dialects: [spark]"); + } + + @TestTemplate + public void replacingTrinoAndSparkViewShouldFail() { + String viewName = viewName("trinoAndSparkView"); + String sql = String.format("SELECT id FROM %s", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("trino", sql) + .withQuery("spark", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + .withSchema(schema(sql)) + .create(); + + assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS %s", viewName, sql)) + .isInstanceOf(IllegalStateException.class) + .hasMessage( + "Cannot replace view due to loss of view dialects (replace.drop-dialect.allowed=false):\n" + + "Previous dialects: [trino, spark]\n" + + "New dialects: [spark]"); + } + + @TestTemplate + public void replacingViewWithDialectDropAllowed() { + String viewName = viewName("trinoView"); + String sql = String.format("SELECT id FROM %s", tableName); + + ViewCatalog viewCatalog = viewCatalog(); + + viewCatalog + .buildView(TableIdentifier.of(NAMESPACE, viewName)) + .withQuery("trino", sql) + .withDefaultNamespace(NAMESPACE) + .withDefaultCatalog(catalogName) + 
.withSchema(schema(sql)) + .create(); + + // allowing to drop the trino dialect should replace the view + sql( + "CREATE OR REPLACE VIEW %s TBLPROPERTIES ('%s'='true') AS SELECT id FROM %s", + viewName, ViewProperties.REPLACE_DROP_DIALECT_ALLOWED, tableName); + + View view = viewCatalog.loadView(TableIdentifier.of(NAMESPACE, viewName)); + assertThat(view.currentVersion().representations()) + .hasSize(1) + .first() + .asInstanceOf(InstanceOfAssertFactories.type(SQLViewRepresentation.class)) + .isEqualTo(ImmutableSQLViewRepresentation.builder().dialect("spark").sql(sql).build()); + + // trino view should show up in the view versions & history + assertThat(view.history()).hasSize(2); + assertThat(view.history()).element(0).extracting(ViewHistoryEntry::versionId).isEqualTo(1); + assertThat(view.history()).element(1).extracting(ViewHistoryEntry::versionId).isEqualTo(2); + + assertThat(view.versions()).hasSize(2); + assertThat(view.versions()).element(0).extracting(ViewVersion::versionId).isEqualTo(1); + assertThat(view.versions()).element(1).extracting(ViewVersion::versionId).isEqualTo(2); + + assertThat(Lists.newArrayList(view.versions()).get(0).representations()) + .hasSize(1) + .first() + .asInstanceOf(InstanceOfAssertFactories.type(SQLViewRepresentation.class)) + .isEqualTo(ImmutableSQLViewRepresentation.builder().dialect("trino").sql(sql).build()); + + assertThat(Lists.newArrayList(view.versions()).get(1).representations()) + .hasSize(1) + .first() + .asInstanceOf(InstanceOfAssertFactories.type(SQLViewRepresentation.class)) + .isEqualTo(ImmutableSQLViewRepresentation.builder().dialect("spark").sql(sql).build()); + } + + @TestTemplate + public void createViewWithRecursiveCycle() { + String viewOne = viewName("viewOne"); + String viewTwo = viewName("viewTwo"); + + sql("CREATE VIEW %s AS SELECT * FROM %s", viewOne, tableName); + // viewTwo points to viewOne + sql("CREATE VIEW %s AS SELECT * FROM %s", viewTwo, viewOne); + + // viewOne points to viewTwo points to viewOne, creating a recursive cycle + String view1 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewOne); + String view2 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewTwo); + String cycle = String.format("%s -> %s -> %s", view1, view2, view1); + assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS SELECT * FROM %s", viewOne, view2)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + String.format("Recursive cycle in view detected: %s (cycle: %s)", view1, cycle)); + } + + @TestTemplate + public void createViewWithRecursiveCycleToV1View() { + String viewOne = viewName("view_one"); + String viewTwo = viewName("view_two"); + + sql("CREATE VIEW %s AS SELECT * FROM %s", viewOne, tableName); + // viewTwo points to viewOne + sql("USE spark_catalog"); + sql("CREATE VIEW %s AS SELECT * FROM %s.%s.%s", viewTwo, catalogName, NAMESPACE, viewOne); + + sql("USE %s", catalogName); + // viewOne points to viewTwo points to viewOne, creating a recursive cycle + String view1 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewOne); + String view2 = String.format("%s.%s.%s", "spark_catalog", NAMESPACE, viewTwo); + String cycle = String.format("%s -> %s -> %s", view1, view2, view1); + assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS SELECT * FROM %s", viewOne, view2)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + String.format("Recursive cycle in view detected: %s (cycle: %s)", view1, cycle)); + } + + @TestTemplate + public void createViewWithRecursiveCycleInCTE() { + String 
viewOne = viewName("viewOne"); + String viewTwo = viewName("viewTwo"); + + sql("CREATE VIEW %s AS SELECT * FROM %s", viewOne, tableName); + // viewTwo points to viewOne + sql("CREATE VIEW %s AS SELECT * FROM %s", viewTwo, viewOne); + + // CTE points to viewTwo + String sql = + String.format( + "WITH max_by_data AS (SELECT max(id) as max FROM %s) " + + "SELECT max, count(1) AS count FROM max_by_data GROUP BY max", + viewTwo); + + // viewOne points to CTE, creating a recursive cycle + String view1 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewOne); + String cycle = String.format("%s -> %s -> %s", view1, viewTwo, view1); + assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS %s", viewOne, sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + String.format("Recursive cycle in view detected: %s (cycle: %s)", view1, cycle)); + } + + @TestTemplate + public void createViewWithRecursiveCycleInSubqueryExpression() { + String viewOne = viewName("viewOne"); + String viewTwo = viewName("viewTwo"); + + sql("CREATE VIEW %s AS SELECT * FROM %s", viewOne, tableName); + // viewTwo points to viewOne + sql("CREATE VIEW %s AS SELECT * FROM %s", viewTwo, viewOne); + + // subquery expression points to viewTwo + String sql = + String.format("SELECT * FROM %s WHERE id = (SELECT id FROM %s)", tableName, viewTwo); + + // viewOne points to subquery expression, creating a recursive cycle + String view1 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewOne); + String cycle = String.format("%s -> %s -> %s", view1, viewTwo, view1); + assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS %s", viewOne, sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + String.format("Recursive cycle in view detected: %s (cycle: %s)", view1, cycle)); + } + + private void insertRows(int numRows) throws NoSuchTableException { + List<SimpleRecord> records = Lists.newArrayListWithCapacity(numRows); + for (int i = 1; i <= numRows; i++) { + records.add(new SimpleRecord(i, Integer.toString(i * 2))); + } + + Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + } +} diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestWriteAborts.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestWriteAborts.java new file mode 100644 index 000000000000..68b17ef36ff3 --- /dev/null +++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestWriteAborts.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+@ExtendWith(ParameterizedTestExtension.class)
+public class TestWriteAborts extends ExtensionsTestBase {
+
+  @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+  public static Object[][] parameters() {
+    return new Object[][] {
+      {
+        "testhive",
+        SparkCatalog.class.getName(),
+        ImmutableMap.of(
+            "type",
+            "hive",
+            CatalogProperties.FILE_IO_IMPL,
+            CustomFileIO.class.getName(),
+            "default-namespace",
+            "default")
+      },
+      {
+        "testhivebulk",
+        SparkCatalog.class.getName(),
+        ImmutableMap.of(
+            "type",
+            "hive",
+            CatalogProperties.FILE_IO_IMPL,
+            CustomBulkFileIO.class.getName(),
+            "default-namespace",
+            "default")
+      }
+    };
+  }
+
+  @AfterEach
+  public void removeTables() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  @TestTemplate
+  public void testBatchAppend() throws IOException {
+    String dataLocation = Files.createTempDirectory(temp, "junit").toFile().toString();
+
+    sql(
+        "CREATE TABLE %s (id INT, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (data)"
+            + "TBLPROPERTIES ('%s' '%s')",
+        tableName, TableProperties.WRITE_DATA_LOCATION, dataLocation);
+
+    List<SimpleRecord> records =
+        ImmutableList.of(
+            new SimpleRecord(1, "a"),
+            new SimpleRecord(2, "b"),
+            new SimpleRecord(3, "a"),
+            new SimpleRecord(4, "b"));
+    Dataset<Row> inputDF = spark.createDataFrame(records, SimpleRecord.class);
+
+    assertThatThrownBy(
+            () ->
+                // incoming records are not ordered by partitions so the job must fail
+                inputDF
+                    .coalesce(1)
+                    .sortWithinPartitions("id")
+                    .writeTo(tableName)
+                    .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false")
+                    .option(SparkWriteOptions.FANOUT_ENABLED, "false")
+                    .append())
+        .isInstanceOf(SparkException.class)
+        .hasMessageContaining("Encountered records that belong to already closed files");
+
+    assertEquals("Should be no records", sql("SELECT * FROM %s", tableName), ImmutableList.of());
+
+    assertEquals(
+        "Should be no orphan data files",
+        ImmutableList.of(),
+        sql(
+            "CALL %s.system.remove_orphan_files(table => '%s', older_than => %dL, location => '%s')",
+            catalogName, tableName, System.currentTimeMillis() + 5000, dataLocation));
+  }
+
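+  /**
+   * Minimal FileIO that delegates all calls to a HadoopFileIO. The catalog is configured to use
+   * it (see the parameters above) so the abort cleanup path runs through this class; the bulk
+   * variant below routes cleanup through deleteFiles instead.
+   */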
+  public static class CustomFileIO implements FileIO {
+
+    private final FileIO delegate = new HadoopFileIO(new Configuration());
+
+    public CustomFileIO() {}
+
+    protected FileIO delegate() {
+      return delegate;
+    }
+
+    @Override
+    public InputFile newInputFile(String path) {
+      return delegate.newInputFile(path);
+    }
+
+    @Override
+    public OutputFile newOutputFile(String path) {
+      return delegate.newOutputFile(path);
+    }
+
+    @Override
+    public void deleteFile(String path) {
+      delegate.deleteFile(path);
+    }
+
+    @Override
+    public Map<String, String> properties() {
+      return delegate.properties();
+    }
+
+    @Override
+    public void initialize(Map<String, String> properties) {
+      delegate.initialize(properties);
+    }
+
+    @Override
+    public void close() {
+      delegate.close();
+    }
+  }
+
+  public static class CustomBulkFileIO extends CustomFileIO implements SupportsBulkOperations {
+
+    public CustomBulkFileIO() {}
+
+    @Override
+    public void deleteFile(String path) {
+      throw new UnsupportedOperationException("Only bulk deletes are supported");
+    }
+
+    @Override
+    public void deleteFiles(Iterable<String> paths) throws BulkDeletionFailureException {
+      for (String path : paths) {
+        delegate().deleteFile(path);
+      }
+    }
+  }
+}
diff --git a/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/source/PlanUtils.java b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/source/PlanUtils.java
new file mode 100644
index 000000000000..148717e14255
--- /dev/null
+++ b/spark/v4.0/spark-extensions/src/test/java/org/apache/iceberg/spark/source/PlanUtils.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation;
+import scala.PartialFunction;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+
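+/**
+ * Helpers for tests that inspect Spark logical plans: collect the Iceberg filter expressions
+ * pushed down into a SparkBatchQueryScan, or collect Spark Catalyst expressions matching a given
+ * predicate anywhere in a plan tree.
+ */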
+public class PlanUtils {
+  private PlanUtils() {}
+
+  public static List<org.apache.iceberg.expressions.Expression> collectPushDownFilters(
+      LogicalPlan logicalPlan) {
+    return JavaConverters.asJavaCollection(logicalPlan.collectLeaves()).stream()
+        .flatMap(
+            plan -> {
+              if (!(plan instanceof DataSourceV2ScanRelation)) {
+                return Stream.empty();
+              }
+
+              DataSourceV2ScanRelation scanRelation = (DataSourceV2ScanRelation) plan;
+              if (!(scanRelation.scan() instanceof SparkBatchQueryScan)) {
+                return Stream.empty();
+              }
+
+              SparkBatchQueryScan batchQueryScan = (SparkBatchQueryScan) scanRelation.scan();
+              return batchQueryScan.filterExpressions().stream();
+            })
+        .collect(Collectors.toList());
+  }
+
+  public static List<Expression> collectSparkExpressions(
+      LogicalPlan logicalPlan, Predicate<Expression> predicate) {
+    Seq<List<Expression>> list =
+        logicalPlan.collect(
+            new PartialFunction<LogicalPlan, List<Expression>>() {
+
+              @Override
+              public List<Expression> apply(LogicalPlan plan) {
+                return JavaConverters.asJavaCollection(plan.expressions()).stream()
+                    .flatMap(expr -> collectSparkExpressions(expr, predicate).stream())
+                    .collect(Collectors.toList());
+              }
+
+              @Override
+              public boolean isDefinedAt(LogicalPlan plan) {
+                return true;
+              }
+            });
+
+    return JavaConverters.asJavaCollection(list).stream()
+        .flatMap(Collection::stream)
+        .collect(Collectors.toList());
+  }
+
+  private static List<Expression> collectSparkExpressions(
+      Expression expression, Predicate<Expression> predicate) {
+    Seq<Expression> list =
+        expression.collect(
+            new PartialFunction<Expression, Expression>() {
+              @Override
+              public Expression apply(Expression expr) {
+                return expr;
+              }
+
+              @Override
+              public boolean isDefinedAt(Expression expr) {
+                return predicate.test(expr);
+              }
+            });
+
+    return Lists.newArrayList(JavaConverters.asJavaCollection(list));
+  }
+}
diff --git a/spark/v4.0/spark-runtime/LICENSE b/spark/v4.0/spark-runtime/LICENSE
new file mode 100644
index 000000000000..1d3e877720d7
--- /dev/null
+++ b/spark/v4.0/spark-runtime/LICENSE
@@ -0,0 +1,639 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Avro. + +Copyright: 2014-2017 The Apache Software Foundation. 
+Home page: https://avro.apache.org/
+License: http://www.apache.org/licenses/LICENSE-2.0
+
+--------------------------------------------------------------------------------
+
+This binary artifact contains the Jackson JSON processor.
+
+Copyright: 2007-2019 Tatu Saloranta and other contributors
+Home page: http://jackson.codehaus.org/
+License: http://www.apache.org/licenses/LICENSE-2.0.txt
+
+--------------------------------------------------------------------------------
+
+This binary artifact contains Paranamer.
+
+Copyright: 2000-2007 INRIA, France Telecom, 2006-2018 Paul Hammant & ThoughtWorks Inc
+Home page: https://github.com/paul-hammant/paranamer
+License: https://github.com/paul-hammant/paranamer/blob/master/LICENSE.txt (BSD)
+
+License text:
+| Portions copyright (c) 2006-2018 Paul Hammant & ThoughtWorks Inc
+| Portions copyright (c) 2000-2007 INRIA, France Telecom
+| All rights reserved.
+|
+| Redistribution and use in source and binary forms, with or without
+| modification, are permitted provided that the following conditions
+| are met:
+| 1. Redistributions of source code must retain the above copyright
+|    notice, this list of conditions and the following disclaimer.
+| 2. Redistributions in binary form must reproduce the above copyright
+|    notice, this list of conditions and the following disclaimer in the
+|    documentation and/or other materials provided with the distribution.
+| 3. Neither the name of the copyright holders nor the names of its
+|    contributors may be used to endorse or promote products derived from
+|    this software without specific prior written permission.
+|
+| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+| ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+| CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+| SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+| INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+| CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+| ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+| THE POSSIBILITY OF SUCH DAMAGE.
+
+--------------------------------------------------------------------------------
+
+This binary artifact contains Apache Parquet.
+
+Copyright: 2014-2017 The Apache Software Foundation.
+Home page: https://parquet.apache.org/
+License: http://www.apache.org/licenses/LICENSE-2.0
+
+--------------------------------------------------------------------------------
+
+This binary artifact contains Apache Thrift.
+
+Copyright: 2006-2010 The Apache Software Foundation.
+Home page: https://thrift.apache.org/
+License: http://www.apache.org/licenses/LICENSE-2.0
+
+--------------------------------------------------------------------------------
+
+This binary artifact contains code from Daniel Lemire's JavaFastPFOR project.
+
+Copyright: 2013 Daniel Lemire
+Home page: https://github.com/lemire/JavaFastPFOR
+License: Apache License Version 2.0 http://www.apache.org/licenses/LICENSE-2.0
+
+--------------------------------------------------------------------------------
+
+This binary artifact contains fastutil.
+ +Copyright: 2002-2014 Sebastiano Vigna +Home page: http://fastutil.di.unimi.it/ +License: http://www.apache.org/licenses/LICENSE-2.0.html + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache ORC. + +Copyright: 2013-2019 The Apache Software Foundation. +Home page: https://orc.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Hive's storage API via ORC. + +Copyright: 2013-2019 The Apache Software Foundation. +Home page: https://hive.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google protobuf via ORC. + +Copyright: 2008 Google Inc. +Home page: https://developers.google.com/protocol-buffers +License: https://github.com/protocolbuffers/protobuf/blob/master/LICENSE (BSD) + +License text: + +| Copyright 2008 Google Inc. All rights reserved. +| +| Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions are +| met: +| +| * Redistributions of source code must retain the above copyright +| notice, this list of conditions and the following disclaimer. +| * Redistributions in binary form must reproduce the above +| copyright notice, this list of conditions and the following disclaimer +| in the documentation and/or other materials provided with the +| distribution. +| * Neither the name of Google Inc. nor the names of its +| contributors may be used to endorse or promote products derived from +| this software without specific prior written permission. +| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +| "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +| LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +| A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +| OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +| SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +| LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +| DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +| THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +| (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +| OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +| +| Code generated by the Protocol Buffer compiler is owned by the owner +| of the input file used when generating it. This code is not +| standalone and requires a support library to be linked with it. This +| support library is itself covered by the above license. + +-------------------------------------------------------------------------------- + +This binary artifact contains Airlift Aircompressor. + +Copyright: 2011-2019 Aircompressor authors. +Home page: https://github.com/airlift/aircompressor +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Airlift Slice. + +Copyright: 2013-2019 Slice authors. +Home page: https://github.com/airlift/slice +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains JetBrains annotations. 
+ +Copyright: 2000-2020 JetBrains s.r.o. +Home page: https://github.com/JetBrains/java-annotations +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Cloudera Kite. + +Copyright: 2013-2017 Cloudera Inc. +Home page: https://kitesdk.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Presto. + +Copyright: 2016 Facebook and contributors +Home page: https://prestodb.io/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google Guava. + +Copyright: 2006-2019 The Guava Authors +Home page: https://github.com/google/guava +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google Error Prone Annotations. + +Copyright: Copyright 2011-2019 The Error Prone Authors +Home page: https://github.com/google/error-prone +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains findbugs-annotations by Stephen Connolly. + +Copyright: 2011-2016 Stephen Connolly, Greg Lucas +Home page: https://github.com/stephenc/findbugs-annotations +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google j2objc Annotations. + +Copyright: Copyright 2012-2018 Google Inc. +Home page: https://github.com/google/j2objc/tree/master/annotations +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains checkerframework checker-qual Annotations. + +Copyright: 2004-2019 the Checker Framework developers +Home page: https://github.com/typetools/checker-framework +License: https://github.com/typetools/checker-framework/blob/master/LICENSE.txt (MIT license) + +License text: +| The annotations are licensed under the MIT License. (The text of this +| license appears below.) More specifically, all the parts of the Checker +| Framework that you might want to include with your own program use the +| MIT License. This is the checker-qual.jar file and all the files that +| appear in it: every file in a qual/ directory, plus utility files such +| as NullnessUtil.java, RegexUtil.java, SignednessUtil.java, etc. +| In addition, the cleanroom implementations of third-party annotations, +| which the Checker Framework recognizes as aliases for its own +| annotations, are licensed under the MIT License. +| +| Permission is hereby granted, free of charge, to any person obtaining a copy +| of this software and associated documentation files (the "Software"), to deal +| in the Software without restriction, including without limitation the rights +| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +| copies of the Software, and to permit persons to whom the Software is +| furnished to do so, subject to the following conditions: +| +| The above copyright notice and this permission notice shall be included in +| all copies or substantial portions of the Software. 
+| +| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +| THE SOFTWARE. + +-------------------------------------------------------------------------------- + +This binary artifact contains Animal Sniffer Annotations. + +Copyright: 2009-2018 codehaus.org +Home page: https://www.mojohaus.org/animal-sniffer/animal-sniffer-annotations/ +License: https://www.mojohaus.org/animal-sniffer/animal-sniffer-annotations/license.html (MIT license) + +License text: +| The MIT License +| +| Copyright (c) 2009 codehaus.org. +| +| Permission is hereby granted, free of charge, to any person obtaining a copy +| of this software and associated documentation files (the "Software"), to deal +| in the Software without restriction, including without limitation the rights +| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +| copies of the Software, and to permit persons to whom the Software is +| furnished to do so, subject to the following conditions: +| +| The above copyright notice and this permission notice shall be included in +| all copies or substantial portions of the Software. +| +| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +| THE SOFTWARE. + +-------------------------------------------------------------------------------- + +This binary artifact contains Caffeine by Ben Manes. + +Copyright: 2014-2019 Ben Manes and contributors +Home page: https://github.com/ben-manes/caffeine +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Arrow. + +Copyright: 2016-2019 The Apache Software Foundation. +Home page: https://arrow.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Netty's buffer library. + +Copyright: 2014-2020 The Netty Project +Home page: https://netty.io/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google FlatBuffers. + +Copyright: 2013-2020 Google Inc. +Home page: https://google.github.io/flatbuffers/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Carrot Search Labs HPPC. + +Copyright: 2002-2019 Carrot Search s.c. 
+Home page: http://labs.carrotsearch.com/hppc.html +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Apache Lucene via Carrot Search HPPC. + +Copyright: 2011-2020 The Apache Software Foundation. +Home page: https://lucene.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Yetus audience annotations. + +Copyright: 2008-2020 The Apache Software Foundation. +Home page: https://yetus.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains ThreeTen. + +Copyright: 2007-present, Stephen Colebourne & Michael Nascimento Santos. +Home page: https://www.threeten.org/threeten-extra/ +License: https://github.com/ThreeTen/threeten-extra/blob/master/LICENSE.txt (BSD 3-clause) + +License text: + +| All rights reserved. +| +| * Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions are met: +| +| * Redistributions of source code must retain the above copyright notice, +| this list of conditions and the following disclaimer. +| +| * Redistributions in binary form must reproduce the above copyright notice, +| this list of conditions and the following disclaimer in the documentation +| and/or other materials provided with the distribution. +| +| * Neither the name of JSR-310 nor the names of its contributors +| may be used to endorse or promote products derived from this software +| without specific prior written permission. +| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +| "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +| LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +| A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +| CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +| EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +| PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +| PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +| LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +| NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +| SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Project Nessie. + +Copyright: 2020 Dremio Corporation. +Home page: https://projectnessie.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Apache Spark. + +* vectorized reading of definition levels in BaseVectorizedParquetValuesReader.java +* portions of the extensions parser +* casting logic in AssignmentAlignmentSupport +* implementation of SetAccumulator. + +Copyright: 2011-2018 The Apache Software Foundation +Home page: https://spark.apache.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Delta Lake. 
+ +* AssignmentAlignmentSupport is an independent development but UpdateExpressionsSupport in Delta was used as a reference. + +Copyright: 2020 The Delta Lake Project Authors. +Home page: https://delta.io/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary includes code from Apache Commons. + +* Core ArrayUtil. + +Copyright: 2020 The Apache Software Foundation +Home page: https://commons.apache.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache HttpComponents Client. + +Copyright: 1999-2022 The Apache Software Foundation. +Home page: https://hc.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Apache HttpComponents Client. + +* retry and error handling logic in ExponentialHttpRequestRetryStrategy.java + +Copyright: 1999-2022 The Apache Software Foundation. +Home page: https://hc.apache.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 diff --git a/spark/v4.0/spark-runtime/NOTICE b/spark/v4.0/spark-runtime/NOTICE new file mode 100644 index 000000000000..2935bfff8b80 --- /dev/null +++ b/spark/v4.0/spark-runtime/NOTICE @@ -0,0 +1,508 @@ + +Apache Iceberg +Copyright 2017-2024 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Kite, developed at Cloudera, Inc. with +the following copyright notice: + +| Copyright 2013 Cloudera Inc. +| +| Licensed under the Apache License, Version 2.0 (the "License"); +| you may not use this file except in compliance with the License. +| You may obtain a copy of the License at +| +| http://www.apache.org/licenses/LICENSE-2.0 +| +| Unless required by applicable law or agreed to in writing, software +| distributed under the License is distributed on an "AS IS" BASIS, +| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +| See the License for the specific language governing permissions and +| limitations under the License. + +-------------------------------------------------------------------------------- + +This binary artifact includes Apache ORC with the following in its NOTICE file: + +| Apache ORC +| Copyright 2013-2019 The Apache Software Foundation +| +| This product includes software developed by The Apache Software +| Foundation (http://www.apache.org/). +| +| This product includes software developed by Hewlett-Packard: +| (c) Copyright [2014-2015] Hewlett-Packard Development Company, L.P + +-------------------------------------------------------------------------------- + +This binary artifact includes Airlift Aircompressor with the following in its +NOTICE file: + +| Snappy Copyright Notices +| ========================= +| +| * Copyright 2011 Dain Sundstrom +| * Copyright 2011, Google Inc. +| +| +| Snappy License +| =============== +| Copyright 2011, Google Inc. +| All rights reserved. 
+| +| Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions are +| met: +| +| * Redistributions of source code must retain the above copyright +| notice, this list of conditions and the following disclaimer. +| * Redistributions in binary form must reproduce the above +| copyright notice, this list of conditions and the following disclaimer +| in the documentation and/or other materials provided with the +| distribution. +| * Neither the name of Google Inc. nor the names of its +| contributors may be used to endorse or promote products derived from +| this software without specific prior written permission. +| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +| "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +| LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +| A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +| OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +| SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +| LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +| DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +| THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +| (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +| OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +-------------------------------------------------------------------------------- + +This binary artifact includes Carrot Search Labs HPPC with the following in its +NOTICE file: + +| ACKNOWLEDGEMENT +| =============== +| +| HPPC borrowed code, ideas or both from: +| +| * Apache Lucene, http://lucene.apache.org/ +| (Apache license) +| * Fastutil, http://fastutil.di.unimi.it/ +| (Apache license) +| * Koloboke, https://github.com/OpenHFT/Koloboke +| (Apache license) + +-------------------------------------------------------------------------------- + +This binary artifact includes Apache Yetus with the following in its NOTICE +file: + +| Apache Yetus +| Copyright 2008-2020 The Apache Software Foundation +| +| This product includes software developed at +| The Apache Software Foundation (https://www.apache.org/). +| +| --- +| Additional licenses for the Apache Yetus Source/Website: +| --- +| +| +| See LICENSE for terms. + +-------------------------------------------------------------------------------- + +This binary artifact includes Google Protobuf with the following copyright +notice: + +| Copyright 2008 Google Inc. All rights reserved. +| +| Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions are +| met: +| +| * Redistributions of source code must retain the above copyright +| notice, this list of conditions and the following disclaimer. +| * Redistributions in binary form must reproduce the above +| copyright notice, this list of conditions and the following disclaimer +| in the documentation and/or other materials provided with the +| distribution. +| * Neither the name of Google Inc. nor the names of its +| contributors may be used to endorse or promote products derived from +| this software without specific prior written permission. 
+| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +| "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +| LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +| A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +| OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +| SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +| LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +| DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +| THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +| (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +| OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +| +| Code generated by the Protocol Buffer compiler is owned by the owner +| of the input file used when generating it. This code is not +| standalone and requires a support library to be linked with it. This +| support library is itself covered by the above license. + +-------------------------------------------------------------------------------- + +This binary artifact includes Apache Arrow with the following in its NOTICE file: + +| Apache Arrow +| Copyright 2016-2019 The Apache Software Foundation +| +| This product includes software developed at +| The Apache Software Foundation (http://www.apache.org/). +| +| This product includes software from the SFrame project (BSD, 3-clause). +| * Copyright (C) 2015 Dato, Inc. +| * Copyright (c) 2009 Carnegie Mellon University. +| +| This product includes software from the Feather project (Apache 2.0) +| https://github.com/wesm/feather +| +| This product includes software from the DyND project (BSD 2-clause) +| https://github.com/libdynd +| +| This product includes software from the LLVM project +| * distributed under the University of Illinois Open Source +| +| This product includes software from the google-lint project +| * Copyright (c) 2009 Google Inc. All rights reserved. +| +| This product includes software from the mman-win32 project +| * Copyright https://code.google.com/p/mman-win32/ +| * Licensed under the MIT License; +| +| This product includes software from the LevelDB project +| * Copyright (c) 2011 The LevelDB Authors. All rights reserved. +| * Use of this source code is governed by a BSD-style license that can be +| * Moved from Kudu http://github.com/cloudera/kudu +| +| This product includes software from the CMake project +| * Copyright 2001-2009 Kitware, Inc. +| * Copyright 2012-2014 Continuum Analytics, Inc. +| * All rights reserved. +| +| This product includes software from https://github.com/matthew-brett/multibuild (BSD 2-clause) +| * Copyright (c) 2013-2016, Matt Terry and Matthew Brett; all rights reserved. +| +| This product includes software from the Ibis project (Apache 2.0) +| * Copyright (c) 2015 Cloudera, Inc. +| * https://github.com/cloudera/ibis +| +| This product includes software from Dremio (Apache 2.0) +| * Copyright (C) 2017-2018 Dremio Corporation +| * https://github.com/dremio/dremio-oss +| +| This product includes software from Google Guava (Apache 2.0) +| * Copyright (C) 2007 The Guava Authors +| * https://github.com/google/guava +| +| This product include software from CMake (BSD 3-Clause) +| * CMake - Cross Platform Makefile Generator +| * Copyright 2000-2019 Kitware, Inc. and Contributors +| +| The web site includes files generated by Jekyll. 
+| +| -------------------------------------------------------------------------------- +| +| This product includes code from Apache Kudu, which includes the following in +| its NOTICE file: +| +| Apache Kudu +| Copyright 2016 The Apache Software Foundation +| +| This product includes software developed at +| The Apache Software Foundation (http://www.apache.org/). +| +| Portions of this software were developed at +| Cloudera, Inc (http://www.cloudera.com/). +| +| -------------------------------------------------------------------------------- +| +| This product includes code from Apache ORC, which includes the following in +| its NOTICE file: +| +| Apache ORC +| Copyright 2013-2019 The Apache Software Foundation +| +| This product includes software developed by The Apache Software +| Foundation (http://www.apache.org/). +| +| This product includes software developed by Hewlett-Packard: +| (c) Copyright [2014-2015] Hewlett-Packard Development Company, L.P + +-------------------------------------------------------------------------------- + +This binary artifact includes Netty buffers with the following in its NOTICE +file: + +| The Netty Project +| ================= +| +| Please visit the Netty web site for more information: +| +| * https://netty.io/ +| +| Copyright 2014 The Netty Project +| +| The Netty Project licenses this file to you under the Apache License, +| version 2.0 (the "License"); you may not use this file except in compliance +| with the License. You may obtain a copy of the License at: +| +| http://www.apache.org/licenses/LICENSE-2.0 +| +| Unless required by applicable law or agreed to in writing, software +| distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +| WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +| License for the specific language governing permissions and limitations +| under the License. +| +| Also, please refer to each LICENSE..txt file, which is located in +| the 'license' directory of the distribution file, for the license terms of the +| components that this product depends on. +| +| ------------------------------------------------------------------------------- +| This product contains the extensions to Java Collections Framework which has +| been derived from the works by JSR-166 EG, Doug Lea, and Jason T. 
Greene: +| +| * LICENSE: +| * license/LICENSE.jsr166y.txt (Public Domain) +| * HOMEPAGE: +| * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ +| * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ +| +| This product contains a modified version of Robert Harder's Public Domain +| Base64 Encoder and Decoder, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.base64.txt (Public Domain) +| * HOMEPAGE: +| * http://iharder.sourceforge.net/current/java/base64/ +| +| This product contains a modified portion of 'Webbit', an event based +| WebSocket and HTTP server, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.webbit.txt (BSD License) +| * HOMEPAGE: +| * https://github.com/joewalnes/webbit +| +| This product contains a modified portion of 'SLF4J', a simple logging +| facade for Java, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.slf4j.txt (MIT License) +| * HOMEPAGE: +| * http://www.slf4j.org/ +| +| This product contains a modified portion of 'Apache Harmony', an open source +| Java SE, which can be obtained at: +| +| * NOTICE: +| * license/NOTICE.harmony.txt +| * LICENSE: +| * license/LICENSE.harmony.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://archive.apache.org/dist/harmony/ +| +| This product contains a modified portion of 'jbzip2', a Java bzip2 compression +| and decompression library written by Matthew J. Francis. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jbzip2.txt (MIT License) +| * HOMEPAGE: +| * https://code.google.com/p/jbzip2/ +| +| This product contains a modified portion of 'libdivsufsort', a C API library to construct +| the suffix array and the Burrows-Wheeler transformed string for any input string of +| a constant-size alphabet written by Yuta Mori. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.libdivsufsort.txt (MIT License) +| * HOMEPAGE: +| * https://github.com/y-256/libdivsufsort +| +| This product contains a modified portion of Nitsan Wakart's 'JCTools', Java Concurrency Tools for the JVM, +| which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jctools.txt (ASL2 License) +| * HOMEPAGE: +| * https://github.com/JCTools/JCTools +| +| This product optionally depends on 'JZlib', a re-implementation of zlib in +| pure Java, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jzlib.txt (BSD style License) +| * HOMEPAGE: +| * http://www.jcraft.com/jzlib/ +| +| This product optionally depends on 'Compress-LZF', a Java library for encoding and +| decoding data in LZF format, written by Tatu Saloranta. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.compress-lzf.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/ning/compress +| +| This product optionally depends on 'lz4', a LZ4 Java compression +| and decompression library written by Adrien Grand. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.lz4.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/jpountz/lz4-java +| +| This product optionally depends on 'lzma-java', a LZMA Java compression +| and decompression library, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.lzma-java.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/jponge/lzma-java +| +| This product contains a modified portion of 'jfastlz', a Java port of FastLZ compression +| and decompression library written by William Kinney. 
It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jfastlz.txt (MIT License) +| * HOMEPAGE: +| * https://code.google.com/p/jfastlz/ +| +| This product contains a modified portion of and optionally depends on 'Protocol Buffers', Google's data +| interchange format, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.protobuf.txt (New BSD License) +| * HOMEPAGE: +| * https://github.com/google/protobuf +| +| This product optionally depends on 'Bouncy Castle Crypto APIs' to generate +| a temporary self-signed X.509 certificate when the JVM does not provide the +| equivalent functionality. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.bouncycastle.txt (MIT License) +| * HOMEPAGE: +| * http://www.bouncycastle.org/ +| +| This product optionally depends on 'Snappy', a compression library produced +| by Google Inc, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.snappy.txt (New BSD License) +| * HOMEPAGE: +| * https://github.com/google/snappy +| +| This product optionally depends on 'JBoss Marshalling', an alternative Java +| serialization API, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jboss-marshalling.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/jboss-remoting/jboss-marshalling +| +| This product optionally depends on 'Caliper', Google's micro- +| benchmarking framework, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.caliper.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/google/caliper +| +| This product optionally depends on 'Apache Commons Logging', a logging +| framework, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.commons-logging.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://commons.apache.org/logging/ +| +| This product optionally depends on 'Apache Log4J', a logging framework, which +| can be obtained at: +| +| * LICENSE: +| * license/LICENSE.log4j.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://logging.apache.org/log4j/ +| +| This product optionally depends on 'Aalto XML', an ultra-high performance +| non-blocking XML processor, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.aalto-xml.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://wiki.fasterxml.com/AaltoHome +| +| This product contains a modified version of 'HPACK', a Java implementation of +| the HTTP/2 HPACK algorithm written by Twitter. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.hpack.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/twitter/hpack +| +| This product contains a modified version of 'HPACK', a Java implementation of +| the HTTP/2 HPACK algorithm written by Cory Benfield. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.hyper-hpack.txt (MIT License) +| * HOMEPAGE: +| * https://github.com/python-hyper/hpack/ +| +| This product contains a modified version of 'HPACK', a Java implementation of +| the HTTP/2 HPACK algorithm written by Tatsuhiro Tsujikawa. 
It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.nghttp2-hpack.txt (MIT License) +| * HOMEPAGE: +| * https://github.com/nghttp2/nghttp2/ +| +| This product contains a modified portion of 'Apache Commons Lang', a Java library +| provides utilities for the java.lang API, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.commons-lang.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://commons.apache.org/proper/commons-lang/ +| +| +| This product contains the Maven wrapper scripts from 'Maven Wrapper', that provides an easy way to ensure a user has everything necessary to run the Maven build. +| +| * LICENSE: +| * license/LICENSE.mvn-wrapper.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/takari/maven-wrapper +| +| This product contains the dnsinfo.h header file, that provides a way to retrieve the system DNS configuration on MacOS. +| This private header is also used by Apple's open source +| mDNSResponder (https://opensource.apple.com/tarballs/mDNSResponder/). +| +| * LICENSE: +| * license/LICENSE.dnsinfo.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://www.opensource.apple.com/source/configd/configd-453.19/dnsinfo/dnsinfo.h + +-------------------------------------------------------------------------------- + +This binary artifact includes Project Nessie with the following in its NOTICE +file: + +| Dremio +| Copyright 2015-2017 Dremio Corporation +| +| This product includes software developed at +| The Apache Software Foundation (http://www.apache.org/). + diff --git a/spark/v4.0/spark-runtime/src/integration/java/org/apache/iceberg/spark/SmokeTest.java b/spark/v4.0/spark-runtime/src/integration/java/org/apache/iceberg/spark/SmokeTest.java new file mode 100644 index 000000000000..ec445774a452 --- /dev/null +++ b/spark/v4.0/spark-runtime/src/integration/java/org/apache/iceberg/spark/SmokeTest.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.nio.file.Files; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.spark.extensions.ExtensionsTestBase; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class SmokeTest extends ExtensionsTestBase { + @AfterEach + public void dropTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + // Run through our Doc's Getting Started Example + // TODO Update doc example so that it can actually be run, modifications were required for this + // test suite to run + @TestTemplate + public void testGettingStarted() throws IOException { + // Creating a table + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + // Writing + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + assertThat(scalarSql("SELECT COUNT(*) FROM %s", tableName)) + .as("Should have inserted 3 rows") + .isEqualTo(3L); + + sql("DROP TABLE IF EXISTS source PURGE"); + sql( + "CREATE TABLE source (id bigint, data string) USING parquet LOCATION '%s'", + Files.createTempDirectory(temp, "junit")); + sql("INSERT INTO source VALUES (10, 'd'), (11, 'ee')"); + + sql("INSERT INTO %s SELECT id, data FROM source WHERE length(data) = 1", tableName); + assertThat(scalarSql("SELECT COUNT(*) FROM %s", tableName)) + .as("Table should now have 4 rows") + .isEqualTo(4L); + + sql("DROP TABLE IF EXISTS updates PURGE"); + sql( + "CREATE TABLE updates (id bigint, data string) USING parquet LOCATION '%s'", + Files.createTempDirectory(temp, "junit")); + sql("INSERT INTO updates VALUES (1, 'x'), (2, 'x'), (4, 'z')"); + + sql( + "MERGE INTO %s t USING (SELECT * FROM updates) u ON t.id = u.id\n" + + "WHEN MATCHED THEN UPDATE SET t.data = u.data\n" + + "WHEN NOT MATCHED THEN INSERT *", + tableName); + assertThat(scalarSql("SELECT COUNT(*) FROM %s", tableName)) + .as("Table should now have 5 rows") + .isEqualTo(5L); + assertThat(scalarSql("SELECT data FROM %s WHERE id = 1", tableName)) + .as("Record 1 should now have data x") + .isEqualTo("x"); + + // Reading + assertThat( + scalarSql( + "SELECT count(1) as count FROM %s WHERE data = 'x' GROUP BY data ", tableName)) + .as("There should be 2 records with data x") + .isEqualTo(2L); + + // Not supported because of Spark limitation + if (!catalogName.equals("spark_catalog")) { + assertThat(scalarSql("SELECT COUNT(*) FROM %s.snapshots", tableName)) + .as("There should be 3 snapshots") + .isEqualTo(3L); + } + } + + // From Spark DDL Docs section + @TestTemplate + public void testAlterTable() { + sql( + "CREATE TABLE %s (category int, id bigint, data string, ts timestamp) USING iceberg", + tableName); + Table table; + // Add examples + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id)", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD truncate(data, 4)", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD years(ts)", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, category) AS shard", tableName); + table = getTable(); + assertThat(table.spec().fields()).as("Table should have 4 partition fields").hasSize(4); + + // Drop Examples + sql("ALTER TABLE %s DROP PARTITION FIELD bucket(16, id)", tableName); + sql("ALTER TABLE %s DROP PARTITION FIELD 
truncate(data, 4)", tableName); + sql("ALTER TABLE %s DROP PARTITION FIELD years(ts)", tableName); + sql("ALTER TABLE %s DROP PARTITION FIELD shard", tableName); + + table = getTable(); + assertThat(table.spec().isUnpartitioned()).as("Table should be unpartitioned").isTrue(); + + // Sort order examples + sql("ALTER TABLE %s WRITE ORDERED BY category, id", tableName); + sql("ALTER TABLE %s WRITE ORDERED BY category ASC, id DESC", tableName); + sql("ALTER TABLE %s WRITE ORDERED BY category ASC NULLS LAST, id DESC NULLS FIRST", tableName); + table = getTable(); + assertThat(table.sortOrder().fields()).as("Table should be sorted on 2 fields").hasSize(2); + } + + @TestTemplate + public void testCreateTable() { + sql("DROP TABLE IF EXISTS %s", tableName("first")); + sql("DROP TABLE IF EXISTS %s", tableName("second")); + sql("DROP TABLE IF EXISTS %s", tableName("third")); + + sql( + "CREATE TABLE %s (\n" + + " id bigint COMMENT 'unique id',\n" + + " data string)\n" + + "USING iceberg", + tableName("first")); + getTable("first"); // Table should exist + + sql( + "CREATE TABLE %s (\n" + + " id bigint,\n" + + " data string,\n" + + " category string)\n" + + "USING iceberg\n" + + "PARTITIONED BY (category)", + tableName("second")); + Table second = getTable("second"); + assertThat(second.spec().fields()).as("Should be partitioned on 1 column").hasSize(1); + + sql( + "CREATE TABLE %s (\n" + + " id bigint,\n" + + " data string,\n" + + " category string,\n" + + " ts timestamp)\n" + + "USING iceberg\n" + + "PARTITIONED BY (bucket(16, id), days(ts), category)", + tableName("third")); + Table third = getTable("third"); + assertThat(third.spec().fields()).as("Should be partitioned on 3 columns").hasSize(3); + } + + @TestTemplate + public void showView() { + sql("DROP VIEW IF EXISTS %s", "test"); + sql("CREATE VIEW %s AS SELECT 1 AS id", "test"); + assertThat(sql("SHOW VIEWS")).contains(row("default", "test", false)); + } + + private Table getTable(String name) { + return validationCatalog.loadTable(TableIdentifier.of("default", name)); + } + + private Table getTable() { + return validationCatalog.loadTable(tableIdent); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/SparkBenchmarkUtil.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/SparkBenchmarkUtil.java new file mode 100644 index 000000000000..b980c39b5bc3 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/SparkBenchmarkUtil.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.expressions.Attribute; +import org.apache.spark.sql.catalyst.expressions.AttributeReference; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; +import org.apache.spark.sql.catalyst.types.DataTypeUtils; +import org.apache.spark.sql.types.StructType; +import scala.collection.JavaConverters; + +public class SparkBenchmarkUtil { + + private SparkBenchmarkUtil() {} + + public static UnsafeProjection projection(Schema expectedSchema, Schema actualSchema) { + StructType struct = SparkSchemaUtil.convert(actualSchema); + + List refs = + JavaConverters.seqAsJavaListConverter(DataTypeUtils.toAttributes(struct)).asJava(); + List attrs = Lists.newArrayListWithExpectedSize(struct.fields().length); + List exprs = Lists.newArrayListWithExpectedSize(struct.fields().length); + + for (AttributeReference ref : refs) { + attrs.add(ref.toAttribute()); + } + + for (Types.NestedField field : expectedSchema.columns()) { + int indexInIterSchema = struct.fieldIndex(field.name()); + exprs.add(refs.get(indexInIterSchema)); + } + + return UnsafeProjection.create( + JavaConverters.asScalaBufferConverter(exprs).asScala().toSeq(), + JavaConverters.asScalaBufferConverter(attrs).asScala().toSeq()); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/action/DeleteOrphanFilesBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/action/DeleteOrphanFilesBenchmark.java new file mode 100644 index 000000000000..9e3e930803d1 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/action/DeleteOrphanFilesBenchmark.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.action; + +import static org.apache.spark.sql.functions.lit; + +import java.sql.Timestamp; +import java.util.List; +import java.util.Locale; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.io.Files; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Timeout; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * A benchmark that evaluates the performance of remove orphan files action in Spark. + * + *
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=DeleteOrphanFilesBenchmark + * -PjmhOutputPath=benchmark/delete-orphan-files-benchmark-results.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +@Timeout(time = 1000, timeUnit = TimeUnit.HOURS) +public class DeleteOrphanFilesBenchmark { + + private static final String TABLE_NAME = "delete_orphan_perf"; + private static final int NUM_SNAPSHOTS = 1000; + private static final int NUM_FILES = 1000; + + private SparkSession spark; + private final List validAndOrphanPaths = Lists.newArrayList(); + private Table table; + + @Setup + public void setupBench() { + setupSpark(); + initTable(); + appendData(); + addOrphans(); + } + + @TearDown + public void teardownBench() { + tearDownSpark(); + } + + @Benchmark + @Threads(1) + public void testDeleteOrphanFiles(Blackhole blackhole) { + Dataset validAndOrphanPathsDF = + spark + .createDataset(validAndOrphanPaths, Encoders.STRING()) + .withColumnRenamed("value", "file_path") + .withColumn("last_modified", lit(new Timestamp(10000))); + + DeleteOrphanFiles.Result results = + SparkActions.get(spark) + .deleteOrphanFiles(table()) + .compareToFileList(validAndOrphanPathsDF) + .execute(); + blackhole.consume(results); + } + + private void initTable() { + spark.sql( + String.format( + "CREATE TABLE %s(id INT, name STRING)" + + " USING ICEBERG" + + " TBLPROPERTIES ( 'format-version' = '2')", + TABLE_NAME)); + } + + private void appendData() { + String location = table().location(); + PartitionSpec partitionSpec = table().spec(); + + for (int i = 0; i < NUM_SNAPSHOTS; i++) { + AppendFiles appendFiles = table().newFastAppend(); + for (int j = 0; j < NUM_FILES; j++) { + String path = String.format(Locale.ROOT, "%s/path/to/data-%d-%d.parquet", location, i, j); + validAndOrphanPaths.add(path); + DataFile dataFile = + DataFiles.builder(partitionSpec) + .withPath(path) + .withFileSizeInBytes(10) + .withRecordCount(1) + .build(); + appendFiles.appendFile(dataFile); + } + appendFiles.commit(); + } + } + + private void addOrphans() { + String location = table.location(); + // Generate 10% orphan files + int orphanFileCount = (NUM_FILES * NUM_SNAPSHOTS) / 10; + for (int i = 0; i < orphanFileCount; i++) { + validAndOrphanPaths.add( + String.format("%s/path/to/data-%s.parquet", location, UUID.randomUUID())); + } + } + + private Table table() { + if (table == null) { + try { + table = Spark3Util.loadIcebergTable(spark, TABLE_NAME); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return table; + } + + private String catalogWarehouse() { + return Files.createTempDir().getAbsolutePath() + "/" + UUID.randomUUID() + "/"; + } + + private void setupSpark() { + SparkSession.Builder builder = + SparkSession.builder() + .config("spark.sql.catalog.spark_catalog", SparkSessionCatalog.class.getName()) + .config("spark.sql.catalog.spark_catalog.type", "hadoop") + .config("spark.sql.catalog.spark_catalog.warehouse", catalogWarehouse()) + .master("local"); + spark = builder.getOrCreate(); + } + + private void tearDownSpark() { + spark.stop(); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/action/IcebergSortCompactionBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/action/IcebergSortCompactionBenchmark.java new file mode 100644 index 000000000000..95bebc7caed4 --- /dev/null +++ 
b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/action/IcebergSortCompactionBenchmark.java @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.action; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.Collections; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortDirection; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.SizeBasedFileRewriter; +import org.apache.iceberg.relocated.com.google.common.io.Files; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.DataTypes; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Timeout; + +@Fork(1) +@State(Scope.Benchmark) +@Measurement(iterations = 10) +@BenchmarkMode(Mode.SingleShotTime) +@Timeout(time = 1000, timeUnit = TimeUnit.HOURS) +public class IcebergSortCompactionBenchmark { + + private static final String[] NAMESPACE = new String[] {"default"}; + private static final String NAME = "sortbench"; + private static final Identifier IDENT = Identifier.of(NAMESPACE, NAME); + private static final int NUM_FILES = 8; + private static final long NUM_ROWS = 7500000L; + private static final long UNIQUE_VALUES = NUM_ROWS / 4; + + private final Configuration hadoopConf = initHadoopConf(); + private SparkSession 
spark; + + @Setup + public void setupBench() { + setupSpark(); + } + + @TearDown + public void teardownBench() { + tearDownSpark(); + } + + @Setup(Level.Iteration) + public void setupIteration() { + initTable(); + appendData(); + } + + @TearDown(Level.Iteration) + public void cleanUpIteration() throws IOException { + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void sortInt() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortInt2() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol2", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortInt3() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol2", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol3", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol4", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortInt4() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol2", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol3", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol4", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortString() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("stringCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortFourColumns() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("stringCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("dateCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .sortBy("doubleCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortSixColumns() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("stringCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("dateCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .sortBy("timestampCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .sortBy("doubleCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .sortBy("longCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortInt() { + SparkActions.get() + .rewriteDataFiles(table()) + 
.option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .zOrder("intCol") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortInt2() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .zOrder("intCol", "intCol2") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortInt3() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .zOrder("intCol", "intCol2", "intCol3") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortInt4() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .zOrder("intCol", "intCol2", "intCol3", "intCol4") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortString() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .zOrder("stringCol") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortFourColumns() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .zOrder("stringCol", "intCol", "dateCol", "doubleCol") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortSixColumns() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .zOrder("stringCol", "intCol", "dateCol", "timestampCol", "doubleCol", "longCol") + .execute(); + } + + protected Configuration initHadoopConf() { + return new Configuration(); + } + + protected final void initTable() { + Schema schema = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "intCol2", Types.IntegerType.get()), + required(4, "intCol3", Types.IntegerType.get()), + required(5, "intCol4", Types.IntegerType.get()), + required(6, "floatCol", Types.FloatType.get()), + optional(7, "doubleCol", Types.DoubleType.get()), + optional(8, "dateCol", Types.DateType.get()), + optional(9, "timestampCol", Types.TimestampType.withZone()), + optional(10, "stringCol", Types.StringType.get())); + + SparkSessionCatalog catalog; + try { + catalog = + (SparkSessionCatalog) + Spark3Util.catalogAndIdentifier(spark(), "spark_catalog").catalog(); + catalog.dropTable(IDENT); + catalog.createTable( + IDENT, SparkSchemaUtil.convert(schema), new Transform[0], Collections.emptyMap()); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private void appendData() { + Dataset df = + spark() + .range(0, NUM_ROWS * NUM_FILES, 1, NUM_FILES) + .drop("id") + .withColumn("longCol", new RandomGeneratingUDF(UNIQUE_VALUES).randomLongUDF().apply()) + .withColumn( + "intCol", + new RandomGeneratingUDF(UNIQUE_VALUES) + .randomLongUDF() + .apply() + .cast(DataTypes.IntegerType)) + .withColumn( + "intCol2", + new RandomGeneratingUDF(UNIQUE_VALUES) + .randomLongUDF() + .apply() + .cast(DataTypes.IntegerType)) + .withColumn( + "intCol3", + new RandomGeneratingUDF(UNIQUE_VALUES) + .randomLongUDF() + .apply() + .cast(DataTypes.IntegerType)) + .withColumn( + "intCol4", + new RandomGeneratingUDF(UNIQUE_VALUES) + .randomLongUDF() + .apply() + .cast(DataTypes.IntegerType)) + .withColumn( + "floatCol", + new RandomGeneratingUDF(UNIQUE_VALUES) + .randomLongUDF() + .apply() + .cast(DataTypes.FloatType)) + .withColumn( + "doubleCol", + new RandomGeneratingUDF(UNIQUE_VALUES) + .randomLongUDF() + .apply() + .cast(DataTypes.DoubleType)) + .withColumn("dateCol", date_add(current_date(), 
col("intCol").mod(NUM_FILES))) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", new RandomGeneratingUDF(UNIQUE_VALUES).randomString().apply()); + writeData(df); + } + + private void writeData(Dataset df) { + df.write().format("iceberg").mode(SaveMode.Append).save(NAME); + } + + protected final Table table() { + try { + return Spark3Util.loadIcebergTable(spark(), NAME); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + protected final SparkSession spark() { + return spark; + } + + protected String getCatalogWarehouse() { + String location = Files.createTempDir().getAbsolutePath() + "/" + UUID.randomUUID() + "/"; + return location; + } + + protected void cleanupFiles() throws IOException { + spark.sql("DROP TABLE IF EXISTS " + NAME); + } + + protected void setupSpark() { + SparkSession.Builder builder = + SparkSession.builder() + .config( + "spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog") + .config("spark.sql.catalog.spark_catalog.type", "hadoop") + .config("spark.sql.catalog.spark_catalog.warehouse", getCatalogWarehouse()) + .master("local[*]"); + spark = builder.getOrCreate(); + Configuration sparkHadoopConf = spark.sessionState().newHadoopConf(); + hadoopConf.forEach(entry -> sparkHadoopConf.set(entry.getKey(), entry.getValue())); + } + + protected void tearDownSpark() { + spark.stop(); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/action/RandomGeneratingUDF.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/action/RandomGeneratingUDF.java new file mode 100644 index 000000000000..d8f9301a7d82 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/action/RandomGeneratingUDF.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.action; + +import static org.apache.spark.sql.functions.udf; + +import java.io.Serializable; +import java.util.Random; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.RandomUtil; +import org.apache.spark.sql.expressions.UserDefinedFunction; +import org.apache.spark.sql.types.DataTypes; + +class RandomGeneratingUDF implements Serializable { + private final long uniqueValues; + private final Random rand = new Random(); + + RandomGeneratingUDF(long uniqueValues) { + this.uniqueValues = uniqueValues; + } + + UserDefinedFunction randomLongUDF() { + return udf(() -> rand.nextLong() % (uniqueValues / 2), DataTypes.LongType) + .asNondeterministic() + .asNonNullable(); + } + + UserDefinedFunction randomString() { + return udf( + () -> RandomUtil.generatePrimitive(Types.StringType.get(), rand), DataTypes.StringType) + .asNondeterministic() + .asNonNullable(); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersFlatDataBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersFlatDataBenchmark.java new file mode 100644 index 000000000000..3dbee5dfd0f5 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersFlatDataBenchmark.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.parquet; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.common.DynMethods; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.SparkBenchmarkUtil; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetReaders; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; +import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * A benchmark that evaluates the performance of reading Parquet data with a flat schema using + * Iceberg and Spark Parquet readers. + * + *
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=SparkParquetReadersFlatDataBenchmark + * -PjmhOutputPath=benchmark/spark-parquet-readers-flat-data-benchmark-result.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class SparkParquetReadersFlatDataBenchmark { + + private static final DynMethods.UnboundMethod APPLY_PROJECTION = + DynMethods.builder("apply").impl(UnsafeProjection.class, InternalRow.class).build(); + private static final Schema SCHEMA = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "dateCol", Types.DateType.get()), + optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + private static final Schema PROJECTED_SCHEMA = + new Schema( + required(1, "longCol", Types.LongType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(8, "stringCol", Types.StringType.get())); + private static final int NUM_RECORDS = 1000000; + private File dataFile; + + @Setup + public void setupBenchmark() throws IOException { + dataFile = File.createTempFile("parquet-flat-data-benchmark", ".parquet"); + dataFile.delete(); + List records = RandomData.generateList(SCHEMA, NUM_RECORDS, 0L); + try (FileAppender writer = + Parquet.write(Files.localOutput(dataFile)).schema(SCHEMA).named("benchmark").build()) { + writer.addAll(records); + } + } + + @TearDown + public void tearDownBenchmark() { + if (dataFile != null) { + dataFile.delete(); + } + } + + @Benchmark + @Threads(1) + public void readUsingIcebergReader(Blackhole blackHole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type)) + .build()) { + + for (InternalRow row : rows) { + blackHole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type)) + .build()) { + + Iterable unsafeRows = + Iterables.transform( + rows, APPLY_PROJECTION.bind(SparkBenchmarkUtil.projection(SCHEMA, SCHEMA))::invoke); + + for (InternalRow row : unsafeRows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readUsingSparkReader(Blackhole blackhole) throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .readSupport(new ParquetReadSupport()) + .set("org.apache.spark.sql.parquet.row.requested_schema", sparkSchema.json()) + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.caseSensitive", "false") + .set("spark.sql.parquet.fieldId.write.enabled", "false") + .set("spark.sql.parquet.inferTimestampNTZ.enabled", "false") + .set("spark.sql.legacy.parquet.nanosAsLong", "false") + .callInit() + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + 
@Benchmark + @Threads(1) + public void readWithProjectionUsingIcebergReader(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type)) + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type)) + .build()) { + + Iterable unsafeRows = + Iterables.transform( + rows, + APPLY_PROJECTION.bind( + SparkBenchmarkUtil.projection(PROJECTED_SCHEMA, PROJECTED_SCHEMA)) + ::invoke); + + for (InternalRow row : unsafeRows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingSparkReader(Blackhole blackhole) throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(PROJECTED_SCHEMA); + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .readSupport(new ParquetReadSupport()) + .set("org.apache.spark.sql.parquet.row.requested_schema", sparkSchema.json()) + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.caseSensitive", "false") + .set("spark.sql.parquet.inferTimestampNTZ.enabled", "false") + .set("spark.sql.legacy.parquet.nanosAsLong", "false") + .callInit() + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersNestedDataBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersNestedDataBenchmark.java new file mode 100644 index 000000000000..8487988d9e5b --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersNestedDataBenchmark.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.parquet; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.common.DynMethods; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.SparkBenchmarkUtil; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetReaders; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; +import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * A benchmark that evaluates the performance of reading nested Parquet data using Iceberg and Spark + * Parquet readers. + * + *
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=SparkParquetReadersNestedDataBenchmark + * -PjmhOutputPath=benchmark/spark-parquet-readers-nested-data-benchmark-result.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class SparkParquetReadersNestedDataBenchmark { + + private static final DynMethods.UnboundMethod APPLY_PROJECTION = + DynMethods.builder("apply").impl(UnsafeProjection.class, InternalRow.class).build(); + private static final Schema SCHEMA = + new Schema( + required(0, "id", Types.LongType.get()), + optional( + 4, + "nested", + Types.StructType.of( + required(1, "col1", Types.StringType.get()), + required(2, "col2", Types.DoubleType.get()), + required(3, "col3", Types.LongType.get())))); + private static final Schema PROJECTED_SCHEMA = + new Schema( + optional(4, "nested", Types.StructType.of(required(1, "col1", Types.StringType.get())))); + private static final int NUM_RECORDS = 1000000; + private File dataFile; + + @Setup + public void setupBenchmark() throws IOException { + dataFile = File.createTempFile("parquet-nested-data-benchmark", ".parquet"); + dataFile.delete(); + List records = RandomData.generateList(SCHEMA, NUM_RECORDS, 0L); + try (FileAppender writer = + Parquet.write(Files.localOutput(dataFile)).schema(SCHEMA).named("benchmark").build()) { + writer.addAll(records); + } + } + + @TearDown + public void tearDownBenchmark() { + if (dataFile != null) { + dataFile.delete(); + } + } + + @Benchmark + @Threads(1) + public void readUsingIcebergReader(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type)) + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type)) + .build()) { + + Iterable unsafeRows = + Iterables.transform( + rows, APPLY_PROJECTION.bind(SparkBenchmarkUtil.projection(SCHEMA, SCHEMA))::invoke); + + for (InternalRow row : unsafeRows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readUsingSparkReader(Blackhole blackhole) throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .readSupport(new ParquetReadSupport()) + .set("org.apache.spark.sql.parquet.row.requested_schema", sparkSchema.json()) + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.caseSensitive", "false") + .set("spark.sql.parquet.fieldId.write.enabled", "false") + .set("spark.sql.parquet.inferTimestampNTZ.enabled", "false") + .set("spark.sql.legacy.parquet.nanosAsLong", "false") + .callInit() + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingIcebergReader(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .createReaderFunc(type -> 
SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type)) + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type)) + .build()) { + + Iterable unsafeRows = + Iterables.transform( + rows, + APPLY_PROJECTION.bind( + SparkBenchmarkUtil.projection(PROJECTED_SCHEMA, PROJECTED_SCHEMA)) + ::invoke); + + for (InternalRow row : unsafeRows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingSparkReader(Blackhole blackhole) throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(PROJECTED_SCHEMA); + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .readSupport(new ParquetReadSupport()) + .set("org.apache.spark.sql.parquet.row.requested_schema", sparkSchema.json()) + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.caseSensitive", "false") + .set("spark.sql.parquet.inferTimestampNTZ.enabled", "false") + .set("spark.sql.legacy.parquet.nanosAsLong", "false") + .callInit() + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersFlatDataBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersFlatDataBenchmark.java new file mode 100644 index 000000000000..47f0b72088f5 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersFlatDataBenchmark.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.parquet; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.datasources.parquet.ParquetWriteSupport; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; + +/** + * A benchmark that evaluates the performance of writing Parquet data with a flat schema using + * Iceberg and Spark Parquet writers. + * + *
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=SparkParquetWritersFlatDataBenchmark + * -PjmhOutputPath=benchmark/spark-parquet-writers-flat-data-benchmark-result.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class SparkParquetWritersFlatDataBenchmark { + + private static final Schema SCHEMA = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "dateCol", Types.DateType.get()), + optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + private static final int NUM_RECORDS = 1000000; + private Iterable rows; + private File dataFile; + + @Setup + public void setupBenchmark() throws IOException { + rows = RandomData.generateSpark(SCHEMA, NUM_RECORDS, 0L); + dataFile = File.createTempFile("parquet-flat-data-benchmark", ".parquet"); + dataFile.delete(); + } + + @TearDown(Level.Iteration) + public void tearDownBenchmark() { + if (dataFile != null) { + dataFile.delete(); + } + } + + @Benchmark + @Threads(1) + public void writeUsingIcebergWriter() throws IOException { + try (FileAppender writer = + Parquet.write(Files.localOutput(dataFile)) + .createWriterFunc( + msgType -> + SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(SCHEMA), msgType)) + .schema(SCHEMA) + .build()) { + + writer.addAll(rows); + } + } + + @Benchmark + @Threads(1) + public void writeUsingSparkWriter() throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + try (FileAppender writer = + Parquet.write(Files.localOutput(dataFile)) + .writeSupport(new ParquetWriteSupport()) + .set("org.apache.spark.sql.parquet.row.attributes", sparkSchema.json()) + .set("spark.sql.parquet.writeLegacyFormat", "false") + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") + .set("spark.sql.caseSensitive", "false") + .set("spark.sql.parquet.fieldId.write.enabled", "false") + .schema(SCHEMA) + .build()) { + + writer.addAll(rows); + } + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersNestedDataBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersNestedDataBenchmark.java new file mode 100644 index 000000000000..4df890d86164 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersNestedDataBenchmark.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.parquet; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.datasources.parquet.ParquetWriteSupport; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; + +/** + * A benchmark that evaluates the performance of writing nested Parquet data using Iceberg and Spark + * Parquet writers. + * + *
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=SparkParquetWritersNestedDataBenchmark + * -PjmhOutputPath=benchmark/spark-parquet-writers-nested-data-benchmark-result.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class SparkParquetWritersNestedDataBenchmark { + + private static final Schema SCHEMA = + new Schema( + required(0, "id", Types.LongType.get()), + optional( + 4, + "nested", + Types.StructType.of( + required(1, "col1", Types.StringType.get()), + required(2, "col2", Types.DoubleType.get()), + required(3, "col3", Types.LongType.get())))); + private static final int NUM_RECORDS = 1000000; + private Iterable rows; + private File dataFile; + + @Setup + public void setupBenchmark() throws IOException { + rows = RandomData.generateSpark(SCHEMA, NUM_RECORDS, 0L); + dataFile = File.createTempFile("parquet-nested-data-benchmark", ".parquet"); + dataFile.delete(); + } + + @TearDown(Level.Iteration) + public void tearDownBenchmark() { + if (dataFile != null) { + dataFile.delete(); + } + } + + @Benchmark + @Threads(1) + public void writeUsingIcebergWriter() throws IOException { + try (FileAppender writer = + Parquet.write(Files.localOutput(dataFile)) + .createWriterFunc( + msgType -> + SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(SCHEMA), msgType)) + .schema(SCHEMA) + .build()) { + + writer.addAll(rows); + } + } + + @Benchmark + @Threads(1) + public void writeUsingSparkWriter() throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + try (FileAppender writer = + Parquet.write(Files.localOutput(dataFile)) + .writeSupport(new ParquetWriteSupport()) + .set("org.apache.spark.sql.parquet.row.attributes", sparkSchema.json()) + .set("spark.sql.parquet.writeLegacyFormat", "false") + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") + .set("spark.sql.caseSensitive", "false") + .set("spark.sql.parquet.fieldId.write.enabled", "false") + .schema(SCHEMA) + .build()) { + + writer.addAll(rows); + } + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/Action.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/Action.java new file mode 100644 index 000000000000..0dbf07285060 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/Action.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +@FunctionalInterface +public interface Action { + void invoke(); +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/DVReaderBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/DVReaderBenchmark.java new file mode 100644 index 000000000000..c6794e43c636 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/DVReaderBenchmark.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import com.google.errorprone.annotations.FormatMethod; +import com.google.errorprone.annotations.FormatString; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileGenerationUtil; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.BaseDeleteLoader; +import org.apache.iceberg.data.DeleteLoader; +import org.apache.iceberg.deletes.BaseDVFileWriter; +import org.apache.iceberg.deletes.DVFileWriter; +import org.apache.iceberg.deletes.DeleteGranularity; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.deletes.PositionDeleteIndex; +import org.apache.iceberg.io.DeleteWriteResult; +import org.apache.iceberg.io.FanoutPositionOnlyDeleteWriter; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.util.ContentFileUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.unsafe.types.UTF8String; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import 
org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Timeout; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * A benchmark that compares the performance of DV and position delete readers. + * + *

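+ * The referencedDataFileCount and deletedRowsRatio parameters control how many data files are
+ * referenced by deletes and what fraction of each file's rows is deleted.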
To run this benchmark for spark-4.0:
+ *   ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh
+ *       -PjmhIncludeRegex=DVReaderBenchmark
+ *       -PjmhOutputPath=benchmark/iceberg-dv-reader-benchmark-result.txt
+ *
+ */
+@Fork(1)
+@State(Scope.Benchmark)
+@Warmup(iterations = 3)
+@Measurement(iterations = 15)
+@Timeout(time = 20, timeUnit = TimeUnit.MINUTES)
+@BenchmarkMode(Mode.SingleShotTime)
+public class DVReaderBenchmark {
+
+  private static final String TABLE_NAME = "test_table";
+  private static final int DATA_FILE_RECORD_COUNT = 2_000_000;
+  private static final long TARGET_FILE_SIZE = Long.MAX_VALUE;
+
+  @Param({"5", "10"})
+  private int referencedDataFileCount;
+
+  @Param({"0.01", "0.03", "0.05", "0.10", "0.2"})
+  private double deletedRowsRatio;
+
+  private final Configuration hadoopConf = new Configuration();
+  private final Random random = ThreadLocalRandom.current();
+  private SparkSession spark;
+  private Table table;
+  private DeleteWriteResult dvsResult;
+  private DeleteWriteResult fileDeletesResult;
+  private DeleteWriteResult partitionDeletesResult;
+
+  @Setup
+  public void setupBenchmark() throws NoSuchTableException, ParseException, IOException {
+    setupSpark();
+    initTable();
+    List deletes = generatePositionDeletes();
+    this.dvsResult = writeDVs(deletes);
+    this.fileDeletesResult = writePositionDeletes(deletes, DeleteGranularity.FILE);
+    this.partitionDeletesResult = writePositionDeletes(deletes, DeleteGranularity.PARTITION);
+  }
+
+  @TearDown
+  public void tearDownBenchmark() {
+    dropTable();
+    tearDownSpark();
+  }
+
+  @Benchmark
+  @Threads(1)
+  public void dv(Blackhole blackhole) {
+    DeleteLoader loader = new BaseDeleteLoader(file -> table.io().newInputFile(file), null);
+    DeleteFile dv = dvsResult.deleteFiles().get(0);
+    CharSequence dataFile = dv.referencedDataFile();
+    PositionDeleteIndex index = loader.loadPositionDeletes(ImmutableList.of(dv), dataFile);
+    blackhole.consume(index);
+  }
+
+  @Benchmark
+  @Threads(1)
+  public void fileScopedParquetDeletes(Blackhole blackhole) {
+    DeleteLoader loader = new BaseDeleteLoader(file -> table.io().newInputFile(file), null);
+    DeleteFile deleteFile = fileDeletesResult.deleteFiles().get(0);
+    CharSequence dataFile = ContentFileUtil.referencedDataFile(deleteFile);
+    PositionDeleteIndex index = loader.loadPositionDeletes(ImmutableList.of(deleteFile), dataFile);
+    blackhole.consume(index);
+  }
+
+  @Benchmark
+  @Threads(1)
+  public void partitionScopedParquetDeletes(Blackhole blackhole) {
+    DeleteLoader loader = new BaseDeleteLoader(file -> table.io().newInputFile(file), null);
+    DeleteFile deleteFile = Iterables.getOnlyElement(partitionDeletesResult.deleteFiles());
+    CharSequence dataFile = Iterables.getLast(partitionDeletesResult.referencedDataFiles());
+    PositionDeleteIndex index = loader.loadPositionDeletes(ImmutableList.of(deleteFile), dataFile);
+    blackhole.consume(index);
+  }
+
+  private FanoutPositionOnlyDeleteWriter newWriter(DeleteGranularity granularity) {
+    return new FanoutPositionOnlyDeleteWriter<>(
+        newWriterFactory(),
+        newFileFactory(FileFormat.PARQUET),
+        table.io(),
+        TARGET_FILE_SIZE,
+        granularity);
+  }
+
+  private SparkFileWriterFactory newWriterFactory() {
+    return SparkFileWriterFactory.builderFor(table).dataFileFormat(FileFormat.PARQUET).build();
+  }
+
+  private OutputFileFactory newFileFactory(FileFormat format) {
+    return OutputFileFactory.builderFor(table, 1, 1).format(format).build();
+  }
+
+  private List generatePositionDeletes() {
+    int numDeletesPerFile = (int)
(DATA_FILE_RECORD_COUNT * deletedRowsRatio); + int numDeletes = referencedDataFileCount * numDeletesPerFile; + List deletes = Lists.newArrayListWithExpectedSize(numDeletes); + + for (int pathIndex = 0; pathIndex < referencedDataFileCount; pathIndex++) { + UTF8String dataFilePath = UTF8String.fromString(generateDataFilePath()); + Set positions = generatePositions(numDeletesPerFile); + for (long pos : positions) { + deletes.add(new GenericInternalRow(new Object[] {dataFilePath, pos})); + } + } + + Collections.shuffle(deletes); + + return deletes; + } + + private DeleteWriteResult writeDVs(Iterable rows) throws IOException { + OutputFileFactory fileFactory = newFileFactory(FileFormat.PUFFIN); + DVFileWriter writer = new BaseDVFileWriter(fileFactory, path -> null); + try (DVFileWriter closableWriter = writer) { + for (InternalRow row : rows) { + String path = row.getString(0); + long pos = row.getLong(1); + closableWriter.delete(path, pos, table.spec(), null); + } + } + return writer.result(); + } + + private DeleteWriteResult writePositionDeletes( + Iterable rows, DeleteGranularity granularity) throws IOException { + FanoutPositionOnlyDeleteWriter writer = newWriter(granularity); + try (FanoutPositionOnlyDeleteWriter closableWriter = writer) { + PositionDelete positionDelete = PositionDelete.create(); + for (InternalRow row : rows) { + String path = row.getString(0); + long pos = row.getLong(1); + positionDelete.set(path, pos, null /* no row */); + closableWriter.write(positionDelete, table.spec(), null); + } + } + return writer.result(); + } + + public Set generatePositions(int numPositions) { + Set positions = Sets.newHashSet(); + + while (positions.size() < numPositions) { + long pos = random.nextInt(DATA_FILE_RECORD_COUNT); + positions.add(pos); + } + + return positions; + } + + private String generateDataFilePath() { + String fileName = FileGenerationUtil.generateFileName(); + return table.locationProvider().newDataLocation(table.spec(), null, fileName); + } + + private void setupSpark() { + this.spark = + SparkSession.builder() + .config("spark.ui.enabled", false) + .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .config("spark.sql.catalog.spark_catalog", SparkSessionCatalog.class.getName()) + .config("spark.sql.catalog.spark_catalog.type", "hadoop") + .config("spark.sql.catalog.spark_catalog.warehouse", newWarehouseDir()) + .master("local[*]") + .getOrCreate(); + } + + private void tearDownSpark() { + spark.stop(); + } + + private void initTable() throws NoSuchTableException, ParseException { + sql("CREATE TABLE %s (c1 INT, c2 INT, c3 STRING) USING iceberg", TABLE_NAME); + this.table = Spark3Util.loadIcebergTable(spark, TABLE_NAME); + } + + private void dropTable() { + sql("DROP TABLE IF EXISTS %s PURGE", TABLE_NAME); + } + + private String newWarehouseDir() { + return hadoopConf.get("hadoop.tmp.dir") + UUID.randomUUID(); + } + + @FormatMethod + private void sql(@FormatString String query, Object... args) { + spark.sql(String.format(query, args)); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/DVWriterBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/DVWriterBenchmark.java new file mode 100644 index 000000000000..ac74fb5a109c --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/DVWriterBenchmark.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import com.google.errorprone.annotations.FormatMethod; +import com.google.errorprone.annotations.FormatString; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileGenerationUtil; +import org.apache.iceberg.Table; +import org.apache.iceberg.deletes.BaseDVFileWriter; +import org.apache.iceberg.deletes.DVFileWriter; +import org.apache.iceberg.deletes.DeleteGranularity; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.io.DeleteWriteResult; +import org.apache.iceberg.io.FanoutPositionOnlyDeleteWriter; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.unsafe.types.UTF8String; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Timeout; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * A benchmark that compares the performance of DV and position delete writers. + * + *

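+ * The referencedDataFileCount and deletedRowsRatio parameters control how many data files
+ * receive deletes and what fraction of each file's rows is deleted.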
To run this benchmark for spark-4.0:
+ *   ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh
+ *       -PjmhIncludeRegex=DVWriterBenchmark
+ *       -PjmhOutputPath=benchmark/iceberg-dv-writer-benchmark-result.txt
+ *
+ */
+@Fork(1)
+@State(Scope.Benchmark)
+@Warmup(iterations = 3)
+@Measurement(iterations = 10)
+@Timeout(time = 20, timeUnit = TimeUnit.MINUTES)
+@BenchmarkMode(Mode.SingleShotTime)
+public class DVWriterBenchmark {
+
+  private static final String TABLE_NAME = "test_table";
+  private static final int DATA_FILE_RECORD_COUNT = 2_000_000;
+  private static final long TARGET_FILE_SIZE = Long.MAX_VALUE;
+
+  @Param({"5", "10"})
+  private int referencedDataFileCount;
+
+  @Param({"0.01", "0.03", "0.05", "0.10", "0.2"})
+  private double deletedRowsRatio;
+
+  private final Configuration hadoopConf = new Configuration();
+  private final Random random = ThreadLocalRandom.current();
+  private SparkSession spark;
+  private Table table;
+  private Iterable positionDeletes;
+
+  @Setup
+  public void setupBenchmark() throws NoSuchTableException, ParseException {
+    setupSpark();
+    initTable();
+    generatePositionDeletes();
+  }
+
+  @TearDown
+  public void tearDownBenchmark() {
+    dropTable();
+    tearDownSpark();
+  }
+
+  @Benchmark
+  @Threads(1)
+  public void dv(Blackhole blackhole) throws IOException {
+    OutputFileFactory fileFactory = newFileFactory(FileFormat.PUFFIN);
+    DVFileWriter writer = new BaseDVFileWriter(fileFactory, path -> null);
+
+    try (DVFileWriter closableWriter = writer) {
+      for (InternalRow row : positionDeletes) {
+        String path = row.getString(0);
+        long pos = row.getLong(1);
+        closableWriter.delete(path, pos, table.spec(), null);
+      }
+    }
+
+    DeleteWriteResult result = writer.result();
+    blackhole.consume(result);
+  }
+
+  @Benchmark
+  @Threads(1)
+  public void fileScopedParquetDeletes(Blackhole blackhole) throws IOException {
+    FanoutPositionOnlyDeleteWriter writer = newWriter(DeleteGranularity.FILE);
+    write(writer, positionDeletes);
+    DeleteWriteResult result = writer.result();
+    blackhole.consume(result);
+  }
+
+  @Benchmark
+  @Threads(1)
+  public void partitionScopedParquetDeletes(Blackhole blackhole) throws IOException {
+    FanoutPositionOnlyDeleteWriter writer = newWriter(DeleteGranularity.PARTITION);
+    write(writer, positionDeletes);
+    DeleteWriteResult result = writer.result();
+    blackhole.consume(result);
+  }
+
+  private FanoutPositionOnlyDeleteWriter newWriter(DeleteGranularity granularity) {
+    return new FanoutPositionOnlyDeleteWriter<>(
+        newWriterFactory(),
+        newFileFactory(FileFormat.PARQUET),
+        table.io(),
+        TARGET_FILE_SIZE,
+        granularity);
+  }
+
+  private DeleteWriteResult write(
+      FanoutPositionOnlyDeleteWriter writer, Iterable rows)
+      throws IOException {
+
+    try (FanoutPositionOnlyDeleteWriter closableWriter = writer) {
+      PositionDelete positionDelete = PositionDelete.create();
+
+      for (InternalRow row : rows) {
+        String path = row.getString(0);
+        long pos = row.getLong(1);
+        positionDelete.set(path, pos, null /* no row */);
+        closableWriter.write(positionDelete, table.spec(), null);
+      }
+    }
+
+    return writer.result();
+  }
+
+  private SparkFileWriterFactory newWriterFactory() {
+    return SparkFileWriterFactory.builderFor(table).dataFileFormat(FileFormat.PARQUET).build();
+  }
+
+  private OutputFileFactory newFileFactory(FileFormat format) {
+    return OutputFileFactory.builderFor(table, 1, 1).format(format).build();
+  }
+
+  private void generatePositionDeletes() {
+    int numDeletesPerFile = (int) (DATA_FILE_RECORD_COUNT * deletedRowsRatio);
+    int numDeletes =
referencedDataFileCount * numDeletesPerFile; + List deletes = Lists.newArrayListWithExpectedSize(numDeletes); + + for (int pathIndex = 0; pathIndex < referencedDataFileCount; pathIndex++) { + UTF8String dataFilePath = UTF8String.fromString(generateDataFilePath()); + Set positions = generatePositions(numDeletesPerFile); + for (long pos : positions) { + deletes.add(new GenericInternalRow(new Object[] {dataFilePath, pos})); + } + } + + Collections.shuffle(deletes); + + this.positionDeletes = deletes; + } + + public Set generatePositions(int numPositions) { + Set positions = Sets.newHashSet(); + + while (positions.size() < numPositions) { + long pos = random.nextInt(DATA_FILE_RECORD_COUNT); + positions.add(pos); + } + + return positions; + } + + private String generateDataFilePath() { + String fileName = FileGenerationUtil.generateFileName(); + return table.locationProvider().newDataLocation(table.spec(), null, fileName); + } + + private void setupSpark() { + this.spark = + SparkSession.builder() + .config("spark.ui.enabled", false) + .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .config("spark.sql.catalog.spark_catalog", SparkSessionCatalog.class.getName()) + .config("spark.sql.catalog.spark_catalog.type", "hadoop") + .config("spark.sql.catalog.spark_catalog.warehouse", newWarehouseDir()) + .master("local[*]") + .getOrCreate(); + } + + private void tearDownSpark() { + spark.stop(); + } + + private void initTable() throws NoSuchTableException, ParseException { + sql("CREATE TABLE %s (c1 INT, c2 INT, c3 STRING) USING iceberg", TABLE_NAME); + this.table = Spark3Util.loadIcebergTable(spark, TABLE_NAME); + } + + private void dropTable() { + sql("DROP TABLE IF EXISTS %s PURGE", TABLE_NAME); + } + + private String newWarehouseDir() { + return hadoopConf.get("hadoop.tmp.dir") + UUID.randomUUID(); + } + + @FormatMethod + private void sql(@FormatString String query, Object... args) { + spark.sql(String.format(query, args)); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceBenchmark.java new file mode 100644 index 000000000000..68c537e34a4a --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceBenchmark.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.Map; +import java.util.UUID; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.UpdateProperties; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public abstract class IcebergSourceBenchmark { + + private final Configuration hadoopConf = initHadoopConf(); + private final Table table = initTable(); + private SparkSession spark; + + protected abstract Configuration initHadoopConf(); + + protected final Configuration hadoopConf() { + return hadoopConf; + } + + protected abstract Table initTable(); + + protected final Table table() { + return table; + } + + protected final SparkSession spark() { + return spark; + } + + protected String newTableLocation() { + String tmpDir = hadoopConf.get("hadoop.tmp.dir"); + Path tablePath = new Path(tmpDir, "spark-iceberg-table-" + UUID.randomUUID()); + return tablePath.toString(); + } + + protected String dataLocation() { + Map properties = table.properties(); + return properties.getOrDefault( + TableProperties.WRITE_DATA_LOCATION, String.format("%s/data", table.location())); + } + + protected void cleanupFiles() throws IOException { + try (FileSystem fileSystem = FileSystem.get(hadoopConf)) { + Path dataPath = new Path(dataLocation()); + fileSystem.delete(dataPath, true); + Path tablePath = new Path(table.location()); + fileSystem.delete(tablePath, true); + } + } + + protected void setupSpark(boolean enableDictionaryEncoding) { + SparkSession.Builder builder = SparkSession.builder().config("spark.ui.enabled", false); + if (!enableDictionaryEncoding) { + builder + .config("parquet.dictionary.page.size", "1") + .config("parquet.enable.dictionary", false) + .config(TableProperties.PARQUET_DICT_SIZE_BYTES, "1"); + } + builder.master("local"); + spark = builder.getOrCreate(); + Configuration sparkHadoopConf = spark.sessionState().newHadoopConf(); + hadoopConf.forEach(entry -> sparkHadoopConf.set(entry.getKey(), entry.getValue())); + } + + protected void setupSpark() { + setupSpark(false); + } + + protected void tearDownSpark() { + spark.stop(); + } + + protected void materialize(Dataset ds) { + ds.queryExecution().toRdd().toJavaRDD().foreach(record -> {}); + } + + protected void materialize(Dataset ds, Blackhole blackhole) { + blackhole.consume(ds.queryExecution().toRdd().toJavaRDD().count()); + } + + protected void appendAsFile(Dataset ds) { + // ensure the schema is precise (including nullability) + StructType sparkSchema = SparkSchemaUtil.convert(table.schema()); + 
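+    // coalesce(1) below funnels the rows through a single task, so each append produces one data file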
spark + .createDataFrame(ds.rdd(), sparkSchema) + .coalesce(1) + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(table.location()); + } + + protected void withSQLConf(Map conf, Action action) { + SQLConf sqlConf = SQLConf.get(); + + Map currentConfValues = Maps.newHashMap(); + conf.keySet() + .forEach( + confKey -> { + if (sqlConf.contains(confKey)) { + String currentConfValue = sqlConf.getConfString(confKey); + currentConfValues.put(confKey, currentConfValue); + } + }); + + conf.forEach( + (confKey, confValue) -> { + if (SQLConf.isStaticConfigKey(confKey)) { + throw new RuntimeException("Cannot modify the value of a static config: " + confKey); + } + sqlConf.setConfString(confKey, confValue); + }); + + try { + action.invoke(); + } finally { + conf.forEach( + (confKey, confValue) -> { + if (currentConfValues.containsKey(confKey)) { + sqlConf.setConfString(confKey, currentConfValues.get(confKey)); + } else { + sqlConf.unsetConf(confKey); + } + }); + } + } + + protected void withTableProperties(Map props, Action action) { + Map tableProps = table.properties(); + Map currentPropValues = Maps.newHashMap(); + props + .keySet() + .forEach( + propKey -> { + if (tableProps.containsKey(propKey)) { + String currentPropValue = tableProps.get(propKey); + currentPropValues.put(propKey, currentPropValue); + } + }); + + UpdateProperties updateProperties = table.updateProperties(); + props.forEach(updateProperties::set); + updateProperties.commit(); + + try { + action.invoke(); + } finally { + UpdateProperties restoreProperties = table.updateProperties(); + props.forEach( + (propKey, propValue) -> { + if (currentPropValues.containsKey(propKey)) { + restoreProperties.set(propKey, currentPropValues.get(propKey)); + } else { + restoreProperties.remove(propKey); + } + }); + restoreProperties.commit(); + } + } + + protected FileFormat fileFormat() { + throw new UnsupportedOperationException("Unsupported file format"); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceDeleteBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceDeleteBenchmark.java new file mode 100644 index 000000000000..e42707bf102b --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceDeleteBenchmark.java @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.ClusteredEqualityDeleteWriter; +import org.apache.iceberg.io.ClusteredPositionDeleteWriter; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.infra.Blackhole; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class IcebergSourceDeleteBenchmark extends IcebergSourceBenchmark { + private static final Logger LOG = LoggerFactory.getLogger(IcebergSourceDeleteBenchmark.class); + private static final long TARGET_FILE_SIZE_IN_BYTES = 512L * 1024 * 1024; + + protected static final int NUM_FILES = 1; + protected static final int NUM_ROWS = 10 * 1000 * 1000; + + @Setup + public void setupBenchmark() throws IOException { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg(Blackhole blackhole) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + tableProperties.put(PARQUET_VECTORIZATION_ENABLED, "false"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df, blackhole); + }); + } + + @Benchmark + @Threads(1) + public void readIcebergWithIsDeletedColumn(Blackhole blackhole) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + tableProperties.put(PARQUET_VECTORIZATION_ENABLED, "false"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).filter("_deleted = false"); + materialize(df, blackhole); + }); + } + + 
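+  // scans only the rows marked as deleted, exposed through the _deleted metadata column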
@Benchmark + @Threads(1) + public void readDeletedRows(Blackhole blackhole) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + tableProperties.put(PARQUET_VECTORIZATION_ENABLED, "false"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).filter("_deleted = true"); + materialize(df, blackhole); + }); + } + + @Benchmark + @Threads(1) + public void readIcebergVectorized(Blackhole blackhole) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + tableProperties.put(PARQUET_VECTORIZATION_ENABLED, "true"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df, blackhole); + }); + } + + @Benchmark + @Threads(1) + public void readIcebergWithIsDeletedColumnVectorized(Blackhole blackhole) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + tableProperties.put(PARQUET_VECTORIZATION_ENABLED, "true"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).filter("_deleted = false"); + materialize(df, blackhole); + }); + } + + @Benchmark + @Threads(1) + public void readDeletedRowsVectorized(Blackhole blackhole) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + tableProperties.put(PARQUET_VECTORIZATION_ENABLED, "true"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).filter("_deleted = true"); + materialize(df, blackhole); + }); + } + + protected abstract void appendData() throws IOException; + + protected void writeData(int fileNum) { + Dataset df = + spark() + .range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(MOD(longCol, 2147483647) AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + + @Override + protected Table initTable() { + Schema schema = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(6, "dateCol", Types.DateType.get()), + optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + properties.put(TableProperties.FORMAT_VERSION, "2"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + protected void writePosDeletes(CharSequence path, 
long numRows, double percentage) + throws IOException { + writePosDeletes(path, numRows, percentage, 1); + } + + protected void writePosDeletes( + CharSequence path, long numRows, double percentage, int numDeleteFile) throws IOException { + writePosDeletesWithNoise(path, numRows, percentage, 0, numDeleteFile); + } + + protected void writePosDeletesWithNoise( + CharSequence path, long numRows, double percentage, int numNoise, int numDeleteFile) + throws IOException { + Set deletedPos = Sets.newHashSet(); + while (deletedPos.size() < numRows * percentage) { + deletedPos.add(ThreadLocalRandom.current().nextLong(numRows)); + } + LOG.info("pos delete row count: {}, num of delete files: {}", deletedPos.size(), numDeleteFile); + + int partitionSize = (int) (numRows * percentage) / numDeleteFile; + Iterable> sets = Iterables.partition(deletedPos, partitionSize); + for (List item : sets) { + writePosDeletes(path, item, numNoise); + } + } + + protected void writePosDeletes(CharSequence path, List deletedPos, int numNoise) + throws IOException { + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()).dataFileFormat(fileFormat()).build(); + + ClusteredPositionDeleteWriter writer = + new ClusteredPositionDeleteWriter<>( + writerFactory, fileFactory, table().io(), TARGET_FILE_SIZE_IN_BYTES); + + PartitionSpec unpartitionedSpec = table().specs().get(0); + + PositionDelete positionDelete = PositionDelete.create(); + try (ClusteredPositionDeleteWriter closeableWriter = writer) { + for (Long pos : deletedPos) { + positionDelete.set(path, pos, null); + closeableWriter.write(positionDelete, unpartitionedSpec, null); + for (int i = 0; i < numNoise; i++) { + positionDelete.set(noisePath(path), pos, null); + closeableWriter.write(positionDelete, unpartitionedSpec, null); + } + } + } + + RowDelta rowDelta = table().newRowDelta(); + writer.result().deleteFiles().forEach(rowDelta::addDeletes); + rowDelta.validateDeletedFiles().commit(); + } + + protected void writeEqDeletes(long numRows, double percentage) throws IOException { + Set deletedValues = Sets.newHashSet(); + while (deletedValues.size() < numRows * percentage) { + deletedValues.add(ThreadLocalRandom.current().nextLong(numRows)); + } + + List rows = Lists.newArrayList(); + for (Long value : deletedValues) { + GenericInternalRow genericInternalRow = new GenericInternalRow(7); + genericInternalRow.setLong(0, value); + genericInternalRow.setInt(1, (int) (value % Integer.MAX_VALUE)); + genericInternalRow.setFloat(2, (float) value); + genericInternalRow.setNullAt(3); + genericInternalRow.setNullAt(4); + genericInternalRow.setNullAt(5); + genericInternalRow.setNullAt(6); + rows.add(genericInternalRow); + } + LOG.info("Num of equality deleted rows: {}", rows.size()); + + writeEqDeletes(rows); + } + + private void writeEqDeletes(List rows) throws IOException { + int equalityFieldId = table().schema().findField("longCol").fieldId(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .equalityDeleteRowSchema(table().schema()) + .equalityFieldIds(new int[] {equalityFieldId}) + .build(); + + ClusteredEqualityDeleteWriter writer = + new ClusteredEqualityDeleteWriter<>( + writerFactory, fileFactory, table().io(), TARGET_FILE_SIZE_IN_BYTES); + + PartitionSpec unpartitionedSpec = table().specs().get(0); + try (ClusteredEqualityDeleteWriter closeableWriter = writer) { + 
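+      // each row becomes one equality delete that matches data rows on the longCol field id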
for (InternalRow row : rows) { + closeableWriter.write(row, unpartitionedSpec, null); + } + } + + RowDelta rowDelta = table().newRowDelta(); + LOG.info("Num of Delete File: {}", writer.result().deleteFiles().size()); + writer.result().deleteFiles().forEach(rowDelta::addDeletes); + rowDelta.validateDeletedFiles().commit(); + } + + private OutputFileFactory newFileFactory() { + return OutputFileFactory.builderFor(table(), 1, 1).format(fileFormat()).build(); + } + + private CharSequence noisePath(CharSequence path) { + // assume the data file name would be something like + // "00000-0-30da64e0-56b5-4743-a11b-3188a1695bf7-00001.parquet" + // so the dataFileSuffixLen is the UUID string length + length of "-00001.parquet", which is 36 + // + 14 = 60. It's OK + // to be not accurate here. + int dataFileSuffixLen = 60; + UUID uuid = UUID.randomUUID(); + if (path.length() > dataFileSuffixLen) { + return path.subSequence(0, dataFileSuffixLen) + uuid.toString(); + } else { + return uuid.toString(); + } + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceFlatDataBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceFlatDataBenchmark.java new file mode 100644 index 000000000000..59e6230350d9 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceFlatDataBenchmark.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; + +public abstract class IcebergSourceFlatDataBenchmark extends IcebergSourceBenchmark { + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected final Table initTable() { + Schema schema = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "dateCol", Types.DateType.get()), + optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedDataBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedDataBenchmark.java new file mode 100644 index 000000000000..a1c61b9b4de0 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedDataBenchmark.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; + +public abstract class IcebergSourceNestedDataBenchmark extends IcebergSourceBenchmark { + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected final Table initTable() { + Schema schema = + new Schema( + required(0, "id", Types.LongType.get()), + optional( + 4, + "nested", + Types.StructType.of( + required(1, "col1", Types.StringType.get()), + required(2, "col2", Types.DoubleType.get()), + required(3, "col3", Types.LongType.get())))); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedListDataBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedListDataBenchmark.java new file mode 100644 index 000000000000..f68b587735dd --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedListDataBenchmark.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; + +public abstract class IcebergSourceNestedListDataBenchmark extends IcebergSourceBenchmark { + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected final Table initTable() { + Schema schema = + new Schema( + required(0, "id", Types.LongType.get()), + optional( + 1, + "outerlist", + Types.ListType.ofOptional( + 2, + Types.StructType.of( + required( + 3, + "innerlist", + Types.ListType.ofRequired(4, Types.StringType.get())))))); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/WritersBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/WritersBenchmark.java new file mode 100644 index 000000000000..963159fe4ee9 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/WritersBenchmark.java @@ -0,0 +1,490 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.IOException; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.deletes.DeleteGranularity; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.ClusteredDataWriter; +import org.apache.iceberg.io.ClusteredEqualityDeleteWriter; +import org.apache.iceberg.io.ClusteredPositionDeleteWriter; +import org.apache.iceberg.io.FanoutDataWriter; +import org.apache.iceberg.io.FanoutPositionOnlyDeleteWriter; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.TaskWriter; +import org.apache.iceberg.io.UnpartitionedWriter; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.transforms.Transform; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.infra.Blackhole; + +public abstract class WritersBenchmark extends IcebergSourceBenchmark { + + private static final int NUM_ROWS = 2500000; + private static final int NUM_DATA_FILES_PER_POSITION_DELETE_FILE = 100; + private static final int NUM_DELETED_POSITIONS_PER_DATA_FILE = 50_000; + private static final int DELETE_POSITION_STEP = 10; + private static final long TARGET_FILE_SIZE_IN_BYTES = 50L * 1024 * 1024; + + private static final Schema SCHEMA = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "timestampCol", Types.TimestampType.withZone()), + optional(7, "stringCol", Types.StringType.get())); + + private Iterable rows; + private Iterable positionDeleteRows; + private Iterable shuffledPositionDeleteRows; + private PartitionSpec unpartitionedSpec; + private PartitionSpec partitionedSpec; + + @Override + protected abstract FileFormat fileFormat(); + + @Setup + public void setupBenchmark() { + setupSpark(); + + List data = Lists.newArrayList(RandomData.generateSpark(SCHEMA, NUM_ROWS, 0L)); + Transform transform = Transforms.bucket(32); + data.sort( + Comparator.comparingInt( + row -> transform.bind(Types.IntegerType.get()).apply(row.getInt(1)))); + this.rows = 
data; + + this.positionDeleteRows = generatePositionDeletes(false /* no shuffle */); + this.shuffledPositionDeleteRows = generatePositionDeletes(true /* shuffle */); + + this.unpartitionedSpec = table().specs().get(0); + Preconditions.checkArgument(unpartitionedSpec.isUnpartitioned()); + this.partitionedSpec = table().specs().get(1); + } + + private Iterable generatePositionDeletes(boolean shuffle) { + int numDeletes = NUM_DATA_FILES_PER_POSITION_DELETE_FILE * NUM_DELETED_POSITIONS_PER_DATA_FILE; + List deletes = Lists.newArrayListWithExpectedSize(numDeletes); + + for (int pathIndex = 0; pathIndex < NUM_DATA_FILES_PER_POSITION_DELETE_FILE; pathIndex++) { + UTF8String path = UTF8String.fromString("path/to/position/delete/file/" + UUID.randomUUID()); + for (long pos = 0; pos < NUM_DELETED_POSITIONS_PER_DATA_FILE; pos++) { + deletes.add(new GenericInternalRow(new Object[] {path, pos * DELETE_POSITION_STEP})); + } + } + + if (shuffle) { + Collections.shuffle(deletes); + } + + return deletes; + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected final Table initTable() { + HadoopTables tables = new HadoopTables(hadoopConf()); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + Table table = tables.create(SCHEMA, spec, properties, newTableLocation()); + + // add a partitioned spec to the table + table.updateSpec().addField(Expressions.bucket("intCol", 32)).commit(); + + return table; + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedClusteredDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .dataSchema(table().schema()) + .build(); + + ClusteredDataWriter writer = + new ClusteredDataWriter<>(writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + try (ClusteredDataWriter closeableWriter = writer) { + for (InternalRow row : rows) { + closeableWriter.write(row, unpartitionedSpec, null); + } + } + + blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedLegacyDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + + Schema writeSchema = table().schema(); + StructType sparkWriteType = SparkSchemaUtil.convert(writeSchema); + SparkAppenderFactory appenders = + SparkAppenderFactory.builderFor(table(), writeSchema, sparkWriteType) + .spec(unpartitionedSpec) + .build(); + + TaskWriter writer = + new UnpartitionedWriter<>( + unpartitionedSpec, fileFormat(), appenders, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + try (TaskWriter closableWriter = writer) { + for (InternalRow row : rows) { + closableWriter.write(row); + } + } + + blackhole.consume(writer.complete()); + } + + @Benchmark + @Threads(1) + public void writePartitionedClusteredDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .dataSchema(table().schema()) + .build(); + + ClusteredDataWriter writer = + new ClusteredDataWriter<>(writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); 
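+    // rows were pre-sorted by bucket(intCol, 32) in setupBenchmark, so the clustered writer
+    // receives each partition as a contiguous block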
+ + PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema()); + StructType dataSparkType = SparkSchemaUtil.convert(table().schema()); + InternalRowWrapper internalRowWrapper = + new InternalRowWrapper(dataSparkType, table().schema().asStruct()); + + try (ClusteredDataWriter closeableWriter = writer) { + for (InternalRow row : rows) { + partitionKey.partition(internalRowWrapper.wrap(row)); + closeableWriter.write(row, partitionedSpec, partitionKey); + } + } + + blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writePartitionedLegacyDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + + Schema writeSchema = table().schema(); + StructType sparkWriteType = SparkSchemaUtil.convert(writeSchema); + SparkAppenderFactory appenders = + SparkAppenderFactory.builderFor(table(), writeSchema, sparkWriteType) + .spec(partitionedSpec) + .build(); + + TaskWriter writer = + new SparkPartitionedWriter( + partitionedSpec, + fileFormat(), + appenders, + fileFactory, + io, + TARGET_FILE_SIZE_IN_BYTES, + writeSchema, + sparkWriteType); + + try (TaskWriter closableWriter = writer) { + for (InternalRow row : rows) { + closableWriter.write(row); + } + } + + blackhole.consume(writer.complete()); + } + + @Benchmark + @Threads(1) + public void writePartitionedFanoutDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .dataSchema(table().schema()) + .build(); + + FanoutDataWriter writer = + new FanoutDataWriter<>(writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema()); + StructType dataSparkType = SparkSchemaUtil.convert(table().schema()); + InternalRowWrapper internalRowWrapper = + new InternalRowWrapper(dataSparkType, table().schema().asStruct()); + + try (FanoutDataWriter closeableWriter = writer) { + for (InternalRow row : rows) { + partitionKey.partition(internalRowWrapper.wrap(row)); + closeableWriter.write(row, partitionedSpec, partitionKey); + } + } + + blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writePartitionedLegacyFanoutDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + + Schema writeSchema = table().schema(); + StructType sparkWriteType = SparkSchemaUtil.convert(writeSchema); + SparkAppenderFactory appenders = + SparkAppenderFactory.builderFor(table(), writeSchema, sparkWriteType) + .spec(partitionedSpec) + .build(); + + TaskWriter writer = + new SparkPartitionedFanoutWriter( + partitionedSpec, + fileFormat(), + appenders, + fileFactory, + io, + TARGET_FILE_SIZE_IN_BYTES, + writeSchema, + sparkWriteType); + + try (TaskWriter closableWriter = writer) { + for (InternalRow row : rows) { + closableWriter.write(row); + } + } + + blackhole.consume(writer.complete()); + } + + @Benchmark + @Threads(1) + public void writePartitionedClusteredEqualityDeleteWriter(Blackhole blackhole) + throws IOException { + FileIO io = table().io(); + + int equalityFieldId = table().schema().findField("longCol").fieldId(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + 
.equalityDeleteRowSchema(table().schema()) + .equalityFieldIds(new int[] {equalityFieldId}) + .build(); + + ClusteredEqualityDeleteWriter writer = + new ClusteredEqualityDeleteWriter<>( + writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema()); + StructType deleteSparkType = SparkSchemaUtil.convert(table().schema()); + InternalRowWrapper internalRowWrapper = + new InternalRowWrapper(deleteSparkType, table().schema().asStruct()); + + try (ClusteredEqualityDeleteWriter closeableWriter = writer) { + for (InternalRow row : rows) { + partitionKey.partition(internalRowWrapper.wrap(row)); + closeableWriter.write(row, partitionedSpec, partitionKey); + } + } + + blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedClusteredPositionDeleteWriterPartitionGranularity( + Blackhole blackhole) throws IOException { + writeUnpartitionedClusteredPositionDeleteWriter(blackhole, DeleteGranularity.PARTITION); + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedClusteredPositionDeleteWriterFileGranularity(Blackhole blackhole) + throws IOException { + writeUnpartitionedClusteredPositionDeleteWriter(blackhole, DeleteGranularity.FILE); + } + + private void writeUnpartitionedClusteredPositionDeleteWriter( + Blackhole blackhole, DeleteGranularity deleteGranularity) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()).dataFileFormat(fileFormat()).build(); + + ClusteredPositionDeleteWriter writer = + new ClusteredPositionDeleteWriter<>( + writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES, deleteGranularity); + + PositionDelete positionDelete = PositionDelete.create(); + try (ClusteredPositionDeleteWriter closeableWriter = writer) { + for (InternalRow row : positionDeleteRows) { + String path = row.getString(0); + long pos = row.getLong(1); + positionDelete.set(path, pos, null); + closeableWriter.write(positionDelete, unpartitionedSpec, null); + } + } + + blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedFanoutPositionDeleteWriterPartitionGranularity(Blackhole blackhole) + throws IOException { + writeUnpartitionedFanoutPositionDeleteWriterPartition(blackhole, DeleteGranularity.PARTITION); + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedFanoutPositionDeleteWriterFileGranularity(Blackhole blackhole) + throws IOException { + writeUnpartitionedFanoutPositionDeleteWriterPartition(blackhole, DeleteGranularity.FILE); + } + + private void writeUnpartitionedFanoutPositionDeleteWriterPartition( + Blackhole blackhole, DeleteGranularity deleteGranularity) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()).dataFileFormat(fileFormat()).build(); + + FanoutPositionOnlyDeleteWriter writer = + new FanoutPositionOnlyDeleteWriter<>( + writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES, deleteGranularity); + + PositionDelete positionDelete = PositionDelete.create(); + try (FanoutPositionOnlyDeleteWriter closeableWriter = writer) { + for (InternalRow row : positionDeleteRows) { + String path = row.getString(0); + long pos = row.getLong(1); + positionDelete.set(path, pos, null); + closeableWriter.write(positionDelete, unpartitionedSpec, null); + } + } + + 
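+    // handing the writer to the Blackhole keeps JMH from eliminating the write path as dead code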
blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedFanoutPositionDeleteWriterShuffledPartitionGranularity( + Blackhole blackhole) throws IOException { + writeUnpartitionedFanoutPositionDeleteWriterShuffled(blackhole, DeleteGranularity.PARTITION); + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedFanoutPositionDeleteWriterShuffledFileGranularity( + Blackhole blackhole) throws IOException { + writeUnpartitionedFanoutPositionDeleteWriterShuffled(blackhole, DeleteGranularity.FILE); + } + + private void writeUnpartitionedFanoutPositionDeleteWriterShuffled( + Blackhole blackhole, DeleteGranularity deleteGranularity) throws IOException { + + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()).dataFileFormat(fileFormat()).build(); + + FanoutPositionOnlyDeleteWriter writer = + new FanoutPositionOnlyDeleteWriter<>( + writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES, deleteGranularity); + + PositionDelete positionDelete = PositionDelete.create(); + try (FanoutPositionOnlyDeleteWriter closeableWriter = writer) { + for (InternalRow row : shuffledPositionDeleteRows) { + String path = row.getString(0); + long pos = row.getLong(1); + positionDelete.set(path, pos, null); + closeableWriter.write(positionDelete, unpartitionedSpec, null); + } + } + + blackhole.consume(writer); + } + + private OutputFileFactory newFileFactory() { + return OutputFileFactory.builderFor(table(), 1, 1).format(fileFormat()).build(); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/AvroWritersBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/AvroWritersBenchmark.java new file mode 100644 index 000000000000..4dcd58c0c4d0 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/AvroWritersBenchmark.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.avro; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.WritersBenchmark; + +/** + * A benchmark that evaluates the performance of various Iceberg writers for Avro data. + * + *
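+ * <p>The only format-specific piece is the {@code fileFormat()} hook inherited from
+ * {@link org.apache.iceberg.spark.source.WritersBenchmark}; as a sketch (class name assumed and
+ * not part of this change), a sibling benchmark for another format only needs to return a
+ * different {@link FileFormat}:
+ *
+ * <pre>{@code
+ * public class OrcWritersBenchmark extends WritersBenchmark {
+ *   @Override
+ *   protected FileFormat fileFormat() {
+ *     return FileFormat.ORC;
+ *   }
+ * }
+ * }</pre>
+ *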
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=AvroWritersBenchmark + * -PjmhOutputPath=benchmark/avro-writers-benchmark-result.txt + * + */ +public class AvroWritersBenchmark extends WritersBenchmark { + + @Override + protected FileFormat fileFormat() { + return FileFormat.AVRO; + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceFlatAvroDataReadBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceFlatAvroDataReadBenchmark.java new file mode 100644 index 000000000000..f0297f644a52 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceFlatAvroDataReadBenchmark.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.avro; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of reading Avro data with a flat schema using Iceberg + * and the built-in file source in Spark. + * + *
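+ * <p>As an illustrative alternative to the Gradle invocation below (not part of this change), the
+ * same benchmark can be launched through the JMH runner API; the output path simply mirrors the
+ * {@code -PjmhOutputPath} value:
+ *
+ * <pre>{@code
+ * // org.openjdk.jmh.runner.Runner, org.openjdk.jmh.runner.options.Options/OptionsBuilder
+ * Options opts =
+ *     new OptionsBuilder()
+ *         .include(IcebergSourceFlatAvroDataReadBenchmark.class.getSimpleName())
+ *         .output("benchmark/iceberg-source-flat-avro-data-read-benchmark-result.txt")
+ *         .build();
+ * new Runner(opts).run();
+ * }</pre>
+ *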
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=IcebergSourceFlatAvroDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-avro-data-read-benchmark-result.txt + * + */ +public class IcebergSourceFlatAvroDataReadBenchmark extends IcebergSourceFlatDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().format("avro").load(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().format("avro").load(dataLocation()).select("longCol"); + materialize(df); + }); + } + + private void appendData() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(DEFAULT_FILE_FORMAT, "avro"); + withTableProperties( + tableProperties, + () -> { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + }); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceNestedAvroDataReadBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceNestedAvroDataReadBenchmark.java new file mode 100644 index 000000000000..00d06566fbaa --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceNestedAvroDataReadBenchmark.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.avro; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of reading Avro data with a flat schema using Iceberg + * and the built-in file source in Spark. + * + *
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=IcebergSourceNestedAvroDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-avro-data-read-benchmark-result.txt + * + */ +public class IcebergSourceNestedAvroDataReadBenchmark extends IcebergSourceNestedDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().format("avro").load(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = + spark().read().format("avro").load(dataLocation()).select("nested.col3"); + materialize(df); + }); + } + + private void appendData() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(DEFAULT_FILE_FORMAT, "avro"); + withTableProperties( + tableProperties, + () -> { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS) + .withColumn( + "nested", + struct( + expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + lit(fileNum).cast("long").as("col3"))); + appendAsFile(df); + } + }); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataBenchmark.java new file mode 100644 index 000000000000..d0fdd8915780 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataBenchmark.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.orc; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceBenchmark; +import org.apache.iceberg.types.Types; + +/** + * Same as {@link org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark} but we disable the + * Timestamp with zone type for ORC performance tests as Spark native reader does not support ORC's + * TIMESTAMP_INSTANT type + */ +public abstract class IcebergSourceFlatORCDataBenchmark extends IcebergSourceBenchmark { + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected final Table initTable() { + Schema schema = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "dateCol", Types.DateType.get()), + // Disable timestamp column for ORC performance tests as Spark native reader does not + // support ORC's + // TIMESTAMP_INSTANT type + // optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java new file mode 100644 index 000000000000..593fbc955703 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.orc; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of reading ORC data with a flat schema using Iceberg + * and the built-in file source in Spark. + * + *
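+ * <p>Both read paths scan the same ORC files, so a quick parity check outside the measured
+ * methods (shown only as a sketch; the set operations resolve columns by position) can confirm
+ * they return identical rows before timings are compared:
+ *
+ * <pre>{@code
+ * Dataset<Row> viaIceberg = spark().read().format("iceberg").load(table().location());
+ * Dataset<Row> viaFileSource = spark().read().orc(dataLocation());
+ * // zero on both sides means the two paths agree
+ * long mismatches =
+ *     viaIceberg.exceptAll(viaFileSource).count() + viaFileSource.exceptAll(viaIceberg).count();
+ * }</pre>
+ *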
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=IcebergSourceFlatORCDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-orc-data-read-benchmark-result.txt + * + */ +public class IcebergSourceFlatORCDataReadBenchmark extends IcebergSourceFlatORCDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIcebergNonVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark() + .read() + .option(SparkReadOptions.VECTORIZATION_ENABLED, "true") + .format("iceberg") + .load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().orc(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().orc(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIcebergNonVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark() + .read() + .option(SparkReadOptions.VECTORIZATION_ENABLED, "true") + .format("iceberg") + .load(tableLocation) + .select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().orc(dataLocation()).select("longCol"); + 
materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().orc(dataLocation()).select("longCol"); + materialize(df); + }); + } + + private void appendData() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(DEFAULT_FILE_FORMAT, "orc"); + withTableProperties( + tableProperties, + () -> { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + }); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedListORCDataWriteBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedListORCDataWriteBenchmark.java new file mode 100644 index 000000000000..0442ed02eb49 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedListORCDataWriteBenchmark.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.orc; + +import static org.apache.spark.sql.functions.array_repeat; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedListDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of writing nested Parquet data using Iceberg and the + * built-in file source in Spark. + * + *
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=IcebergSourceNestedListORCDataWriteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-list-orc-data-write-benchmark-result.txt + * + */ +public class IcebergSourceNestedListORCDataWriteBenchmark + extends IcebergSourceNestedListDataBenchmark { + + @Setup + public void setupBenchmark() { + setupSpark(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Param({"2000", "20000"}) + private int numRows; + + @Benchmark + @Threads(1) + public void writeIceberg() { + String tableLocation = table().location(); + benchmarkData() + .write() + .format("iceberg") + .option("write-format", "orc") + .mode(SaveMode.Append) + .save(tableLocation); + } + + @Benchmark + @Threads(1) + public void writeIcebergDictionaryOff() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put("orc.dictionary.key.threshold", "0"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + benchmarkData() + .write() + .format("iceberg") + .option("write-format", "orc") + .mode(SaveMode.Append) + .save(tableLocation); + }); + } + + @Benchmark + @Threads(1) + public void writeFileSource() { + benchmarkData().write().mode(SaveMode.Append).orc(dataLocation()); + } + + private Dataset benchmarkData() { + return spark() + .range(numRows) + .withColumn( + "outerlist", + array_repeat(struct(expr("array_repeat(CAST(id AS string), 1000) AS innerlist")), 10)) + .coalesce(1); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java new file mode 100644 index 000000000000..a64a23774eec --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source.orc; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of reading ORC data with a flat schema using Iceberg + * and the built-in file source in Spark. + * + *
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=IcebergSourceNestedORCDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-orc-data-read-benchmark-result.txt + * + */ +public class IcebergSourceNestedORCDataReadBenchmark extends IcebergSourceNestedDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIcebergNonVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark() + .read() + .option(SparkReadOptions.VECTORIZATION_ENABLED, "true") + .format("iceberg") + .load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().orc(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIcebergNonVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).selectExpr("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark() + .read() + .option(SparkReadOptions.VECTORIZATION_ENABLED, "true") + .format("iceberg") + .load(tableLocation) + .selectExpr("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().orc(dataLocation()).selectExpr("nested.col3"); + materialize(df); + }); + } + + private void appendData() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(DEFAULT_FILE_FORMAT, "orc"); + withTableProperties( + tableProperties, + () -> { + for (int fileNum = 0; fileNum < NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS) + .withColumn( + "nested", + struct( + 
expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + lit(fileNum).cast("long").as("col3"))); + appendAsFile(df); + } + }); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataFilterBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataFilterBenchmark.java new file mode 100644 index 000000000000..5b7b22f5ead7 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataFilterBenchmark.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the file skipping capabilities in the Spark data source for Iceberg. + * + *
<p>
This class uses a dataset with a flat schema, where the records are clustered according to the + * column used in the filter predicate. + * + *
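+ * <p>File skipping can also be observed at planning time, independent of Spark; the sketch below
+ * is not part of the benchmark and the date literal is a placeholder, but it counts how many data
+ * files Iceberg keeps for an equivalent predicate on the clustered column:
+ *
+ * <pre>{@code
+ * // Expressions, CloseableIterable and Iterables come from the Iceberg API and relocated Guava
+ * try (CloseableIterable<FileScanTask> tasks =
+ *     table().newScan().filter(Expressions.equal("dateCol", "2025-01-02")).planFiles()) {
+ *   long matchingFiles = Iterables.size(tasks); // expected to be far below NUM_FILES
+ * }
+ * }</pre>
+ *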
<p>
The performance is compared to the built-in file source in Spark. + * + *
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=IcebergSourceFlatParquetDataFilterBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-parquet-data-filter-benchmark-result.txt + * + */ +public class IcebergSourceFlatParquetDataFilterBenchmark extends IcebergSourceFlatDataBenchmark { + + private static final String FILTER_COND = "dateCol == date_add(current_date(), 1)"; + private static final int NUM_FILES = 500; + private static final int NUM_ROWS = 10000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readWithFilterIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).filter(FILTER_COND); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithFilterFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().parquet(dataLocation()).filter(FILTER_COND); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithFilterFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().parquet(dataLocation()).filter(FILTER_COND); + materialize(df); + }); + } + + private void appendData() { + for (int fileNum = 1; fileNum < NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataReadBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataReadBenchmark.java new file mode 100644 index 000000000000..ec1514fe4297 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataReadBenchmark.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of reading Parquet data with a flat schema using + * Iceberg and the built-in file source in Spark. + * + *
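+ * <p>Only the default Iceberg read path is measured here; as a sketch, the vectorized Parquet
+ * path can be exercised the same way the ORC benchmarks do, through a read option:
+ *
+ * <pre>{@code
+ * Dataset<Row> df =
+ *     spark()
+ *         .read()
+ *         .option(SparkReadOptions.VECTORIZATION_ENABLED, "true")
+ *         .format("iceberg")
+ *         .load(table().location());
+ * }</pre>
+ *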
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=IcebergSourceFlatParquetDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-parquet-data-read-benchmark-result.txt + * + */ +public class IcebergSourceFlatParquetDataReadBenchmark extends IcebergSourceFlatDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().parquet(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().parquet(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("longCol"); + materialize(df); + }); + } + + private void appendData() { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", 
expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataWriteBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataWriteBenchmark.java new file mode 100644 index 000000000000..787ae389ca6b --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataWriteBenchmark.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of writing Parquet data with a flat schema using + * Iceberg and the built-in file source in Spark. + * + *
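+ * <p>The benchmark appends through the path-based {@code DataFrameWriter} API; purely as a
+ * sketch, if the same table were registered in a Spark catalog (the table name below is a
+ * placeholder), the equivalent append with the V2 writer API would be:
+ *
+ * <pre>{@code
+ * benchmarkData().writeTo("spark_catalog.db.flat_parquet_write_benchmark").append();
+ * }</pre>
+ *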
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=IcebergSourceFlatParquetDataWriteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-parquet-data-write-benchmark-result.txt + * + */ +public class IcebergSourceFlatParquetDataWriteBenchmark extends IcebergSourceFlatDataBenchmark { + + private static final int NUM_ROWS = 5000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void writeIceberg() { + String tableLocation = table().location(); + benchmarkData().write().format("iceberg").mode(SaveMode.Append).save(tableLocation); + } + + @Benchmark + @Threads(1) + public void writeFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_COMPRESSION().key(), "gzip"); + withSQLConf(conf, () -> benchmarkData().write().mode(SaveMode.Append).parquet(dataLocation())); + } + + private Dataset benchmarkData() { + return spark() + .range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", expr("DATE_ADD(CURRENT_DATE(), (intCol % 20))")) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")) + .coalesce(1); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedListParquetDataWriteBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedListParquetDataWriteBenchmark.java new file mode 100644 index 000000000000..0d17bd3e5653 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedListParquetDataWriteBenchmark.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.spark.sql.functions.array_repeat; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedListDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of writing nested Parquet data using Iceberg and the + * built-in file source in Spark. + * + *
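+ * <p>The {@code numRows} parameter values can be narrowed without editing the class, for example
+ * through the same JMH runner API sketched earlier for the Avro read benchmark (shown here only
+ * as an illustration):
+ *
+ * <pre>{@code
+ * new Runner(
+ *         new OptionsBuilder()
+ *             .include(IcebergSourceNestedListParquetDataWriteBenchmark.class.getSimpleName())
+ *             .param("numRows", "2000")
+ *             .build())
+ *     .run();
+ * }</pre>
+ *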
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=IcebergSourceNestedListParquetDataWriteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-list-parquet-data-write-benchmark-result.txt + * + */ +public class IcebergSourceNestedListParquetDataWriteBenchmark + extends IcebergSourceNestedListDataBenchmark { + + @Setup + public void setupBenchmark() { + setupSpark(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Param({"2000", "20000"}) + private int numRows; + + @Benchmark + @Threads(1) + public void writeIceberg() { + String tableLocation = table().location(); + benchmarkData().write().format("iceberg").mode(SaveMode.Append).save(tableLocation); + } + + @Benchmark + @Threads(1) + public void writeFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_COMPRESSION().key(), "gzip"); + withSQLConf(conf, () -> benchmarkData().write().mode(SaveMode.Append).parquet(dataLocation())); + } + + private Dataset benchmarkData() { + return spark() + .range(numRows) + .withColumn( + "outerlist", + array_repeat(struct(expr("array_repeat(CAST(id AS string), 1000) AS innerlist")), 10)) + .coalesce(1); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataFilterBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataFilterBenchmark.java new file mode 100644 index 000000000000..a5ddd060422f --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataFilterBenchmark.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the file skipping capabilities in the Spark data source for Iceberg. + * + *
<p>
This class uses a dataset with nested data, where the records are clustered according to the + * column used in the filter predicate. + * + *
<p>
The performance is compared to the built-in file source in Spark. + * + *
<p>
To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=IcebergSourceNestedParquetDataFilterBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-parquet-data-filter-benchmark-result.txt + * + */ +public class IcebergSourceNestedParquetDataFilterBenchmark + extends IcebergSourceNestedDataBenchmark { + + private static final String FILTER_COND = "nested.col3 == 0"; + private static final int NUM_FILES = 500; + private static final int NUM_ROWS = 10000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readWithFilterIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).filter(FILTER_COND); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithFilterFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().parquet(dataLocation()).filter(FILTER_COND); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithFilterFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().parquet(dataLocation()).filter(FILTER_COND); + materialize(df); + }); + } + + private void appendData() { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS) + .withColumn( + "nested", + struct( + expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + lit(fileNum).cast("long").as("col3"))); + appendAsFile(df); + } + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataReadBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataReadBenchmark.java new file mode 100644 index 000000000000..24e2d99902b4 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataReadBenchmark.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of reading nested Parquet data using Iceberg and the + * built-in file source in Spark. + * + *
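+ * The projection variants below select only nested.col3. For the built-in file source this relies
+ * on nested schema pruning, which the benchmark enables through SQLConf; an illustrative
+ * equivalent on an existing SparkSession named spark would be:
+ *
+ *   spark.conf().set("spark.sql.optimizer.nestedSchemaPruning.enabled", "true");
+ *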

To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=IcebergSourceNestedParquetDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-parquet-data-read-benchmark-result.txt + * + */ +public class IcebergSourceNestedParquetDataReadBenchmark extends IcebergSourceNestedDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().parquet(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().parquet(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).selectExpr("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + conf.put(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED().key(), "true"); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().parquet(dataLocation()).selectExpr("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + conf.put(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED().key(), "true"); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().parquet(dataLocation()).selectExpr("nested.col3"); + materialize(df); + }); + } + + private void appendData() { + for (int fileNum = 0; fileNum < NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS) + .withColumn( + "nested", + struct( + expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + lit(fileNum).cast("long").as("col3"))); + appendAsFile(df); + } + } +} diff --git 
a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataWriteBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataWriteBenchmark.java new file mode 100644 index 000000000000..eef14854c4d6 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataWriteBenchmark.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of writing nested Parquet data using Iceberg and the + * built-in file source in Spark. + * + *

To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh + * -PjmhIncludeRegex=IcebergSourceNestedParquetDataWriteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-parquet-data-write-benchmark-result.txt + * + */ +public class IcebergSourceNestedParquetDataWriteBenchmark extends IcebergSourceNestedDataBenchmark { + + private static final int NUM_ROWS = 5000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void writeIceberg() { + String tableLocation = table().location(); + benchmarkData().write().format("iceberg").mode(SaveMode.Append).save(tableLocation); + } + + @Benchmark + @Threads(1) + public void writeFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_COMPRESSION().key(), "gzip"); + withSQLConf(conf, () -> benchmarkData().write().mode(SaveMode.Append).parquet(dataLocation())); + } + + private Dataset benchmarkData() { + return spark() + .range(NUM_ROWS) + .withColumn( + "nested", + struct( + expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + expr("id AS col3"))) + .coalesce(1); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetEqDeleteBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetEqDeleteBenchmark.java new file mode 100644 index 000000000000..3b54b448b8b5 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetEqDeleteBenchmark.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.IcebergSourceDeleteBenchmark; +import org.openjdk.jmh.annotations.Param; + +/** + * A benchmark that evaluates the non-vectorized read and vectorized read with equality delete in + * the Spark data source for Iceberg. + * + *
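+ * An equality delete removes rows by value (conceptually, a delete row carrying id = 5 drops every
+ * data row whose id is 5), whereas the position-delete benchmarks in this package remove rows by
+ * (file_path, pos) pairs. The percentDeleteRow parameter below controls what fraction of rows in
+ * each data file receive equality deletes.
+ *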

This class uses a dataset with a flat schema. To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0:jmh + * -PjmhIncludeRegex=IcebergSourceParquetEqDeleteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-parquet-eq-delete-benchmark-result.txt + * + */ +public class IcebergSourceParquetEqDeleteBenchmark extends IcebergSourceDeleteBenchmark { + @Param({"0", "0.000001", "0.05", "0.25", "0.5", "1"}) + private double percentDeleteRow; + + @Override + protected void appendData() throws IOException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + writeData(fileNum); + + if (percentDeleteRow > 0) { + // add equality deletes + table().refresh(); + writeEqDeletes(NUM_ROWS, percentDeleteRow); + } + } + } + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetMultiDeleteFileBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetMultiDeleteFileBenchmark.java new file mode 100644 index 000000000000..b86a02de0c50 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetMultiDeleteFileBenchmark.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.IcebergSourceDeleteBenchmark; +import org.openjdk.jmh.annotations.Param; + +/** + * A benchmark that evaluates the non-vectorized read and vectorized read with pos-delete in the + * Spark data source for Iceberg. + * + *

This class uses a dataset with a flat schema. To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0:jmh \ + * -PjmhIncludeRegex=IcebergSourceParquetMultiDeleteFileBenchmark \ + * -PjmhOutputPath=benchmark/iceberg-source-parquet-multi-delete-file-benchmark-result.txt + * + */ +public class IcebergSourceParquetMultiDeleteFileBenchmark extends IcebergSourceDeleteBenchmark { + @Param({"1", "2", "5", "10"}) + private int numDeleteFile; + + @Override + protected void appendData() throws IOException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + writeData(fileNum); + + table().refresh(); + for (DataFile file : table().currentSnapshot().addedDataFiles(table().io())) { + writePosDeletes(file.path(), NUM_ROWS, 0.25, numDeleteFile); + } + } + } + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetPosDeleteBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetPosDeleteBenchmark.java new file mode 100644 index 000000000000..c8c8b4e1a8a2 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetPosDeleteBenchmark.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.IcebergSourceDeleteBenchmark; +import org.openjdk.jmh.annotations.Param; + +/** + * A benchmark that evaluates the non-vectorized read and vectorized read with pos-delete in the + * Spark data source for Iceberg. + * + *

This class uses a dataset with a flat schema. To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0:jmh + * -PjmhIncludeRegex=IcebergSourceParquetPosDeleteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-parquet-pos-delete-benchmark-result.txt + * + */ +public class IcebergSourceParquetPosDeleteBenchmark extends IcebergSourceDeleteBenchmark { + @Param({"0", "0.000001", "0.05", "0.25", "0.5", "1"}) + private double percentDeleteRow; + + @Override + protected void appendData() throws IOException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + writeData(fileNum); + + if (percentDeleteRow > 0) { + // add pos-deletes + table().refresh(); + for (DataFile file : table().currentSnapshot().addedDataFiles(table().io())) { + writePosDeletes(file.path(), NUM_ROWS, percentDeleteRow); + } + } + } + } + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetWithUnrelatedDeleteBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetWithUnrelatedDeleteBenchmark.java new file mode 100644 index 000000000000..3f2e1c22f535 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetWithUnrelatedDeleteBenchmark.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.IcebergSourceDeleteBenchmark; +import org.openjdk.jmh.annotations.Param; + +/** + * A benchmark that evaluates the non-vectorized read and vectorized read with pos-delete in the + * Spark data source for Iceberg. + * + *

This class uses a dataset with a flat schema. To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0:jmh + * -PjmhIncludeRegex=IcebergSourceParquetWithUnrelatedDeleteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-parquet-with-unrelated-delete-benchmark-result.txt + * + */ +public class IcebergSourceParquetWithUnrelatedDeleteBenchmark extends IcebergSourceDeleteBenchmark { + private static final double PERCENT_DELETE_ROW = 0.05; + + @Param({"0", "0.05", "0.25", "0.5"}) + private double percentUnrelatedDeletes; + + @Override + protected void appendData() throws IOException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + writeData(fileNum); + + table().refresh(); + for (DataFile file : table().currentSnapshot().addedDataFiles(table().io())) { + writePosDeletesWithNoise( + file.path(), + NUM_ROWS, + PERCENT_DELETE_ROW, + (int) (percentUnrelatedDeletes / PERCENT_DELETE_ROW), + 1); + } + } + } + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/ParquetWritersBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/ParquetWritersBenchmark.java new file mode 100644 index 000000000000..8bcd871a07da --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/ParquetWritersBenchmark.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.WritersBenchmark; + +/** + * A benchmark that evaluates the performance of various Iceberg writers for Parquet data. + * + *

To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh \ + * -PjmhIncludeRegex=ParquetWritersBenchmark \ + * -PjmhOutputPath=benchmark/parquet-writers-benchmark-result.txt + * + */ +public class ParquetWritersBenchmark extends WritersBenchmark { + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java new file mode 100644 index 000000000000..73d4f6211803 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet.vectorized; + +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.pmod; +import static org.apache.spark.sql.functions.to_date; +import static org.apache.spark.sql.functions.to_timestamp; + +import java.util.Map; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.types.DataTypes; +import org.openjdk.jmh.annotations.Setup; + +/** + * Benchmark to compare performance of reading Parquet dictionary encoded data with a flat schema + * using vectorized Iceberg read path and the built-in file source in Spark. + * + *
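+ * Dictionary encoding is forced here by keeping column cardinality low: every column is derived
+ * from pmod(id, 9), so each data file holds at most nine distinct values per column. The
+ * bigDecimalCol column is the exception, since Parquet v1 does not dictionary-encode fixed-length
+ * binary values (see the comment in appendData below).
+ *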

To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh \ + * -PjmhIncludeRegex=VectorizedReadDictionaryEncodedFlatParquetDataBenchmark \ + * -PjmhOutputPath=benchmark/results.txt + * + */ +public class VectorizedReadDictionaryEncodedFlatParquetDataBenchmark + extends VectorizedReadFlatParquetDataBenchmark { + + @Setup + @Override + public void setupBenchmark() { + setupSpark(true); + appendData(); + } + + @Override + Map parquetWriteProps() { + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return properties; + } + + @Override + void appendData() { + Dataset df = idDF(); + df = withLongColumnDictEncoded(df); + df = withIntColumnDictEncoded(df); + df = withFloatColumnDictEncoded(df); + df = withDoubleColumnDictEncoded(df); + df = withBigDecimalColumnNotDictEncoded(df); // no dictionary for fixed len binary in Parquet v1 + df = withDecimalColumnDictEncoded(df); + df = withDateColumnDictEncoded(df); + df = withTimestampColumnDictEncoded(df); + df = withStringColumnDictEncoded(df); + df = df.drop("id"); + df.write().format("iceberg").mode(SaveMode.Append).save(table().location()); + } + + private static Column modColumn() { + return pmod(col("id"), lit(9)); + } + + private Dataset idDF() { + return spark().range(0, NUM_ROWS_PER_FILE * NUM_FILES, 1, NUM_FILES).toDF(); + } + + private static Dataset withLongColumnDictEncoded(Dataset df) { + return df.withColumn("longCol", modColumn().cast(DataTypes.LongType)); + } + + private static Dataset withIntColumnDictEncoded(Dataset df) { + return df.withColumn("intCol", modColumn().cast(DataTypes.IntegerType)); + } + + private static Dataset withFloatColumnDictEncoded(Dataset df) { + return df.withColumn("floatCol", modColumn().cast(DataTypes.FloatType)); + } + + private static Dataset withDoubleColumnDictEncoded(Dataset df) { + return df.withColumn("doubleCol", modColumn().cast(DataTypes.DoubleType)); + } + + private static Dataset withBigDecimalColumnNotDictEncoded(Dataset df) { + return df.withColumn("bigDecimalCol", modColumn().cast("decimal(20,5)")); + } + + private static Dataset withDecimalColumnDictEncoded(Dataset df) { + return df.withColumn("decimalCol", modColumn().cast("decimal(18,5)")); + } + + private static Dataset withDateColumnDictEncoded(Dataset df) { + Column days = modColumn().cast(DataTypes.ShortType); + return df.withColumn("dateCol", date_add(to_date(lit("04/12/2019"), "MM/dd/yyyy"), days)); + } + + private static Dataset withTimestampColumnDictEncoded(Dataset df) { + Column days = modColumn().cast(DataTypes.ShortType); + return df.withColumn( + "timestampCol", to_timestamp(date_add(to_date(lit("04/12/2019"), "MM/dd/yyyy"), days))); + } + + private static Dataset withStringColumnDictEncoded(Dataset df) { + return df.withColumn("stringCol", modColumn().cast(DataTypes.StringType)); + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadFlatParquetDataBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadFlatParquetDataBenchmark.java new file mode 100644 index 000000000000..6cf327c1cf81 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadFlatParquetDataBenchmark.java @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet.vectorized; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.pmod; +import static org.apache.spark.sql.functions.when; + +import java.io.IOException; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceBenchmark; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * Benchmark to compare performance of reading Parquet data with a flat schema using vectorized + * Iceberg read path and the built-in file source in Spark. + * + *
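+ * On the Iceberg side, vectorization is controlled through table properties; an illustrative
+ * equivalent of what tablePropsWithVectorizationEnabled(5000) applies temporarily below, assuming
+ * an org.apache.iceberg.Table handle named table, would be:
+ *
+ *   table.updateProperties()
+ *       .set(TableProperties.PARQUET_VECTORIZATION_ENABLED, "true")
+ *       .set(TableProperties.PARQUET_BATCH_SIZE, "5000")
+ *       .commit();
+ *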

To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh \ + * -PjmhIncludeRegex=VectorizedReadFlatParquetDataBenchmark \ + * -PjmhOutputPath=benchmark/results.txt + * + */ +public class VectorizedReadFlatParquetDataBenchmark extends IcebergSourceBenchmark { + + static final int NUM_FILES = 5; + static final int NUM_ROWS_PER_FILE = 10_000_000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected Table initTable() { + // bigDecimalCol is big enough to be encoded as fix len binary (9 bytes), + // decimalCol is small enough to be encoded as a 64-bit int + Schema schema = + new Schema( + optional(1, "longCol", Types.LongType.get()), + optional(2, "intCol", Types.IntegerType.get()), + optional(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "bigDecimalCol", Types.DecimalType.of(20, 5)), + optional(6, "decimalCol", Types.DecimalType.of(18, 5)), + optional(7, "dateCol", Types.DateType.get()), + optional(8, "timestampCol", Types.TimestampType.withZone()), + optional(9, "stringCol", Types.StringType.get())); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = parquetWriteProps(); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } + + Map parquetWriteProps() { + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + properties.put(TableProperties.PARQUET_DICT_SIZE_BYTES, "1"); + return properties; + } + + void appendData() { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS_PER_FILE) + .withColumn( + "longCol", + when(pmod(col("id"), lit(10)).equalTo(lit(0)), lit(null)).otherwise(col("id"))) + .drop("id") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("bigDecimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(18, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(longCol AS STRING)")); + appendAsFile(df); + } + } + + @Benchmark + @Threads(1) + public void readIntegersIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("intCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readIntegersSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("intCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readLongsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void 
readLongsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFloatsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("floatCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFloatsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("floatCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDoublesIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("doubleCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDoublesSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("doubleCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDecimalsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("decimalCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDecimalsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("decimalCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readBigDecimalsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("bigDecimalCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readBigDecimalsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("bigDecimalCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDatesIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("dateCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDatesSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("dateCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readTimestampsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("timestampCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readTimestampsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = 
spark().read().parquet(dataLocation()).select("timestampCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readStringsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("stringCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readStringsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("stringCol"); + materialize(df); + }); + } + + private static Map tablePropsWithVectorizationEnabled(int batchSize) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(TableProperties.PARQUET_VECTORIZATION_ENABLED, "true"); + tableProperties.put(TableProperties.PARQUET_BATCH_SIZE, String.valueOf(batchSize)); + return tableProperties; + } + + private static Map sparkConfWithVectorizationEnabled(int batchSize) { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key(), String.valueOf(batchSize)); + return conf; + } +} diff --git a/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadParquetDecimalBenchmark.java b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadParquetDecimalBenchmark.java new file mode 100644 index 000000000000..ccf28e3fdc77 --- /dev/null +++ b/spark/v4.0/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadParquetDecimalBenchmark.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source.parquet.vectorized; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.pmod; +import static org.apache.spark.sql.functions.when; + +import java.io.IOException; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceBenchmark; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * Benchmark to compare performance of reading Parquet decimal data using vectorized Iceberg read + * path and the built-in file source in Spark. + * + *
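+ * The three columns in the schema below exercise the three Parquet physical layouts Iceberg uses
+ * for decimals: decimalCol1 (precision 7) is int32-backed, decimalCol2 (precision 15) is
+ * int64-backed, and decimalCol3 (precision 20) falls back to fixed-length binary, typically the
+ * most expensive of the three to decode.
+ *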

To run this benchmark for spark-4.0: + * ./gradlew -DsparkVersions=4.0 :iceberg-spark:iceberg-spark-4.0_2.13:jmh \ + * -PjmhIncludeRegex=VectorizedReadParquetDecimalBenchmark \ + * -PjmhOutputPath=benchmark/results.txt + * + */ +public class VectorizedReadParquetDecimalBenchmark extends IcebergSourceBenchmark { + + static final int NUM_FILES = 5; + static final int NUM_ROWS_PER_FILE = 10_000_000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + // Allow unsafe memory access to avoid the costly check arrow does to check if index is within + // bounds + System.setProperty("arrow.enable_unsafe_memory_access", "true"); + // Disable expensive null check for every get(index) call. + // Iceberg manages nullability checks itself instead of relying on arrow. + System.setProperty("arrow.enable_null_check_for_get", "false"); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected Table initTable() { + Schema schema = + new Schema( + optional(1, "decimalCol1", Types.DecimalType.of(7, 2)), + optional(2, "decimalCol2", Types.DecimalType.of(15, 2)), + optional(3, "decimalCol3", Types.DecimalType.of(20, 2))); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = parquetWriteProps(); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } + + Map parquetWriteProps() { + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + properties.put(TableProperties.PARQUET_DICT_SIZE_BYTES, "1"); + return properties; + } + + void appendData() { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS_PER_FILE) + .withColumn( + "longCol", + when(pmod(col("id"), lit(10)).equalTo(lit(0)), lit(null)).otherwise(col("id"))) + .drop("id") + .withColumn("decimalCol1", expr("CAST(longCol AS DECIMAL(7, 2))")) + .withColumn("decimalCol2", expr("CAST(longCol AS DECIMAL(15, 2))")) + .withColumn("decimalCol3", expr("CAST(longCol AS DECIMAL(20, 2))")) + .drop("longCol"); + appendAsFile(df); + } + } + + @Benchmark + @Threads(1) + public void readIntBackedDecimalsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("decimalCol1"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readIntBackedDecimalsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("decimalCol1"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readLongBackedDecimalsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("decimalCol2"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readLongBackedDecimalsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("decimalCol2"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void 
readDecimalsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("decimalCol3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDecimalsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("decimalCol3"); + materialize(df); + }); + } + + private static Map tablePropsWithVectorizationEnabled(int batchSize) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(TableProperties.PARQUET_VECTORIZATION_ENABLED, "true"); + tableProperties.put(TableProperties.PARQUET_BATCH_SIZE, String.valueOf(batchSize)); + return tableProperties; + } + + private static Map sparkConfWithVectorizationEnabled(int batchSize) { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key(), String.valueOf(batchSize)); + return conf; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/SparkDistributedDataScan.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/SparkDistributedDataScan.java new file mode 100644 index 000000000000..43ce2a303e2b --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/SparkDistributedDataScan.java @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.ClosingIterator; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.metrics.MetricsReporter; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.JobGroupUtils; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.actions.ManifestFileBean; +import org.apache.iceberg.spark.source.SerializableTableWithSize; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.SparkSession; + +/** + * A batch data scan that can utilize Spark cluster resources for planning. + * + *

This scan remotely filters manifests, fetching only the relevant data and delete files to the + * driver. The delete file assignment is done locally after the remote filtering step. Such an approach + * is beneficial if the remote parallelism is much higher than the number of driver cores. + * + *

This scan is best suited for queries with selective filters on lower/upper bounds across all + * partitions, or against poorly clustered metadata. This allows job planning to benefit from highly + * concurrent remote filtering while not incurring high serialization and data transfer costs. This + * class is also useful for full table scans over large tables, but the cost of bringing data and + * delete file details to the driver may become noticeable. Make sure to follow the performance tips + * below in such cases. + * + *
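+ * A minimal usage sketch (illustrative only; it assumes an existing SparkSession named spark, an
+ * Iceberg Table named table with an id column, and default read options):
+ *
+ *   SparkReadConf readConf = new SparkReadConf(spark, table, ImmutableMap.of());
+ *   BatchScan scan = new SparkDistributedDataScan(spark, table, readConf)
+ *       .filter(Expressions.greaterThan("id", 100));
+ *   try (CloseableIterable<ScanTask> tasks = scan.planFiles()) {
+ *     // consume the planned tasks
+ *   }
+ *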

Ensure the filtered metadata size doesn't exceed the driver's max result size. For large table + * scans, consider increasing `spark.driver.maxResultSize` to avoid job failures. + * + *
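+ * For example, when building the session (the value is illustrative, not a recommendation):
+ *
+ *   SparkSession spark = SparkSession.builder()
+ *       .config("spark.driver.maxResultSize", "4g")
+ *       .getOrCreate();
+ *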

Performance tips: + * + *

+ */ +public class SparkDistributedDataScan extends BaseDistributedDataScan { + + private static final Joiner COMMA = Joiner.on(','); + private static final String DELETE_PLANNING_JOB_GROUP_ID = "DELETE-PLANNING"; + private static final String DATA_PLANNING_JOB_GROUP_ID = "DATA-PLANNING"; + + private final SparkSession spark; + private final JavaSparkContext sparkContext; + private final SparkReadConf readConf; + + private Broadcast tableBroadcast = null; + + public SparkDistributedDataScan(SparkSession spark, Table table, SparkReadConf readConf) { + this(spark, table, readConf, table.schema(), newTableScanContext(table)); + } + + private SparkDistributedDataScan( + SparkSession spark, + Table table, + SparkReadConf readConf, + Schema schema, + TableScanContext context) { + super(table, schema, context); + this.spark = spark; + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.readConf = readConf; + } + + @Override + protected BatchScan newRefinedScan( + Table newTable, Schema newSchema, TableScanContext newContext) { + return new SparkDistributedDataScan(spark, newTable, readConf, newSchema, newContext); + } + + @Override + protected int remoteParallelism() { + return readConf.parallelism(); + } + + @Override + protected PlanningMode dataPlanningMode() { + return readConf.dataPlanningMode(); + } + + @Override + protected boolean shouldCopyRemotelyPlannedDataFiles() { + return false; + } + + @Override + protected Iterable> planDataRemotely( + List dataManifests, boolean withColumnStats) { + JobGroupInfo info = new JobGroupInfo(DATA_PLANNING_JOB_GROUP_ID, jobDesc("data")); + return withJobGroupInfo(info, () -> doPlanDataRemotely(dataManifests, withColumnStats)); + } + + private Iterable> doPlanDataRemotely( + List dataManifests, boolean withColumnStats) { + scanMetrics().scannedDataManifests().increment(dataManifests.size()); + + JavaRDD dataFileRDD = + sparkContext + .parallelize(toBeans(dataManifests), dataManifests.size()) + .flatMap(new ReadDataManifest(tableBroadcast(), context(), withColumnStats)); + List> dataFileGroups = collectPartitions(dataFileRDD); + + int matchingFilesCount = dataFileGroups.stream().mapToInt(List::size).sum(); + int skippedFilesCount = liveFilesCount(dataManifests) - matchingFilesCount; + scanMetrics().skippedDataFiles().increment(skippedFilesCount); + + return Iterables.transform(dataFileGroups, CloseableIterable::withNoopClose); + } + + @Override + protected PlanningMode deletePlanningMode() { + return readConf.deletePlanningMode(); + } + + @Override + protected DeleteFileIndex planDeletesRemotely(List deleteManifests) { + JobGroupInfo info = new JobGroupInfo(DELETE_PLANNING_JOB_GROUP_ID, jobDesc("deletes")); + return withJobGroupInfo(info, () -> doPlanDeletesRemotely(deleteManifests)); + } + + private DeleteFileIndex doPlanDeletesRemotely(List deleteManifests) { + scanMetrics().scannedDeleteManifests().increment(deleteManifests.size()); + + List deleteFiles = + sparkContext + .parallelize(toBeans(deleteManifests), deleteManifests.size()) + .flatMap(new ReadDeleteManifest(tableBroadcast(), context())) + .collect(); + + int skippedFilesCount = liveFilesCount(deleteManifests) - deleteFiles.size(); + scanMetrics().skippedDeleteFiles().increment(skippedFilesCount); + + return DeleteFileIndex.builderFor(deleteFiles) + .specsById(table().specs()) + .caseSensitive(isCaseSensitive()) + .scanMetrics(scanMetrics()) + .build(); + } + + private T withJobGroupInfo(JobGroupInfo info, Supplier supplier) { + return 
JobGroupUtils.withJobGroupInfo(sparkContext, info, supplier); + } + + private String jobDesc(String type) { + List options = Lists.newArrayList(); + options.add("snapshot_id=" + snapshot().snapshotId()); + String optionsAsString = COMMA.join(options); + return String.format("Planning %s (%s) for %s", type, optionsAsString, table().name()); + } + + private List toBeans(List manifests) { + return manifests.stream().map(ManifestFileBean::fromManifest).collect(Collectors.toList()); + } + + private Broadcast
tableBroadcast() { + if (tableBroadcast == null) { + Table serializableTable = SerializableTableWithSize.copyOf(table()); + this.tableBroadcast = sparkContext.broadcast(serializableTable); + } + + return tableBroadcast; + } + + private List> collectPartitions(JavaRDD rdd) { + int[] partitionIds = IntStream.range(0, rdd.getNumPartitions()).toArray(); + return Arrays.asList(rdd.collectPartitions(partitionIds)); + } + + private int liveFilesCount(List manifests) { + return manifests.stream().mapToInt(this::liveFilesCount).sum(); + } + + private int liveFilesCount(ManifestFile manifest) { + return manifest.existingFilesCount() + manifest.addedFilesCount(); + } + + private static TableScanContext newTableScanContext(Table table) { + if (table instanceof BaseTable) { + MetricsReporter reporter = ((BaseTable) table).reporter(); + return ImmutableTableScanContext.builder().metricsReporter(reporter).build(); + } else { + return TableScanContext.empty(); + } + } + + private static class ReadDataManifest implements FlatMapFunction { + + private final Broadcast
table; + private final Expression filter; + private final boolean withStats; + private final boolean isCaseSensitive; + + ReadDataManifest(Broadcast
table, TableScanContext context, boolean withStats) { + this.table = table; + this.filter = context.rowFilter(); + this.withStats = withStats; + this.isCaseSensitive = context.caseSensitive(); + } + + @Override + public Iterator call(ManifestFileBean manifest) throws Exception { + FileIO io = table.value().io(); + Map specs = table.value().specs(); + return new ClosingIterator<>( + ManifestFiles.read(manifest, io, specs) + .select(withStats ? SCAN_WITH_STATS_COLUMNS : SCAN_COLUMNS) + .filterRows(filter) + .caseSensitive(isCaseSensitive) + .iterator()); + } + } + + private static class ReadDeleteManifest implements FlatMapFunction { + + private final Broadcast
table; + private final Expression filter; + private final boolean isCaseSensitive; + + ReadDeleteManifest(Broadcast
table, TableScanContext context) { + this.table = table; + this.filter = context.rowFilter(); + this.isCaseSensitive = context.caseSensitive(); + } + + @Override + public Iterator call(ManifestFileBean manifest) throws Exception { + FileIO io = table.value().io(); + Map specs = table.value().specs(); + return new ClosingIterator<>( + ManifestFiles.readDeleteManifest(manifest, io, specs) + .select(DELETE_SCAN_WITH_STATS_COLUMNS) + .filterRows(filter) + .caseSensitive(isCaseSensitive) + .iterator()); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java new file mode 100644 index 000000000000..2082c0584608 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.spark.procedures.SparkProcedures; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.spark.source.HasIcebergCatalog; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.SupportsNamespaces; +import org.apache.spark.sql.connector.iceberg.catalog.Procedure; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +abstract class BaseCatalog + implements StagingTableCatalog, + ProcedureCatalog, + SupportsNamespaces, + HasIcebergCatalog, + SupportsFunctions { + private static final String USE_NULLABLE_QUERY_SCHEMA_CTAS_RTAS = "use-nullable-query-schema"; + private static final boolean USE_NULLABLE_QUERY_SCHEMA_CTAS_RTAS_DEFAULT = true; + + private boolean useNullableQuerySchema = USE_NULLABLE_QUERY_SCHEMA_CTAS_RTAS_DEFAULT; + + @Override + public Procedure loadProcedure(Identifier ident) throws NoSuchProcedureException { + String[] namespace = ident.namespace(); + String name = ident.name(); + + // namespace resolution is case insensitive until we have a way to configure case sensitivity in + // catalogs + if (isSystemNamespace(namespace)) { + ProcedureBuilder builder = SparkProcedures.newBuilder(name); + if (builder != null) { + return builder.withTableCatalog(this).build(); + } + } + + throw new NoSuchProcedureException(ident); + } + + @Override + public boolean isFunctionNamespace(String[] namespace) { + // Allow for empty namespace, as Spark's storage partitioned joins look up + // the corresponding functions to generate transforms 
for partitioning + // with an empty namespace, such as `bucket`. + // Otherwise, use `system` namespace. + return namespace.length == 0 || isSystemNamespace(namespace); + } + + @Override + public boolean isExistingNamespace(String[] namespace) { + return namespaceExists(namespace); + } + + @Override + public void initialize(String name, CaseInsensitiveStringMap options) { + this.useNullableQuerySchema = + PropertyUtil.propertyAsBoolean( + options, + USE_NULLABLE_QUERY_SCHEMA_CTAS_RTAS, + USE_NULLABLE_QUERY_SCHEMA_CTAS_RTAS_DEFAULT); + } + + @Override + public boolean useNullableQuerySchema() { + return useNullableQuerySchema; + } + + private static boolean isSystemNamespace(String[] namespace) { + return namespace.length == 1 && namespace[0].equalsIgnoreCase("system"); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/BaseFileRewriteCoordinator.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/BaseFileRewriteCoordinator.java new file mode 100644 index 000000000000..5c95475d3302 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/BaseFileRewriteCoordinator.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.util.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +abstract class BaseFileRewriteCoordinator> { + + private static final Logger LOG = LoggerFactory.getLogger(BaseFileRewriteCoordinator.class); + + private final Map, Set> resultMap = Maps.newConcurrentMap(); + + /** + * Called to persist the output of a rewrite action for a specific group. Since the write is done + * via a Spark Datasource, we have to propagate the result through this side-effect call. 
+ * + * @param table table where the rewrite is occurring + * @param fileSetId the id used to identify the source set of files being rewritten + * @param newFiles the new files which have been written + */ + public void stageRewrite(Table table, String fileSetId, Set newFiles) { + LOG.debug( + "Staging the output for {} - fileset {} with {} files", + table.name(), + fileSetId, + newFiles.size()); + Pair id = toId(table, fileSetId); + resultMap.put(id, newFiles); + } + + public Set fetchNewFiles(Table table, String fileSetId) { + Pair id = toId(table, fileSetId); + Set result = resultMap.get(id); + ValidationException.check( + result != null, "No results for rewrite of file set %s in table %s", fileSetId, table); + + return result; + } + + public void clearRewrite(Table table, String fileSetId) { + LOG.debug("Removing entry for {} - id {}", table.name(), fileSetId); + Pair id = toId(table, fileSetId); + resultMap.remove(id); + } + + public Set fetchSetIds(Table table) { + return resultMap.keySet().stream() + .filter(e -> e.first().equals(Spark3Util.baseTableUUID(table))) + .map(Pair::second) + .collect(Collectors.toSet()); + } + + private Pair toId(Table table, String setId) { + return Pair.of(Spark3Util.baseTableUUID(table), setId); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/ChangelogIterator.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/ChangelogIterator.java new file mode 100644 index 000000000000..cc44b1f3992c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/ChangelogIterator.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Iterator; +import java.util.Objects; +import java.util.Set; +import org.apache.iceberg.ChangelogOperation; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Iterators; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructType; + +/** An iterator that transforms rows from changelog tables within a single Spark task. 
*/ +public abstract class ChangelogIterator implements Iterator { + protected static final String DELETE = ChangelogOperation.DELETE.name(); + protected static final String INSERT = ChangelogOperation.INSERT.name(); + protected static final String UPDATE_BEFORE = ChangelogOperation.UPDATE_BEFORE.name(); + protected static final String UPDATE_AFTER = ChangelogOperation.UPDATE_AFTER.name(); + + private final Iterator rowIterator; + private final int changeTypeIndex; + private final StructType rowType; + + protected ChangelogIterator(Iterator rowIterator, StructType rowType) { + this.rowIterator = rowIterator; + this.rowType = rowType; + this.changeTypeIndex = rowType.fieldIndex(MetadataColumns.CHANGE_TYPE.name()); + } + + protected int changeTypeIndex() { + return changeTypeIndex; + } + + protected StructType rowType() { + return rowType; + } + + protected String changeType(Row row) { + String changeType = row.getString(changeTypeIndex()); + Preconditions.checkNotNull(changeType, "Change type should not be null"); + return changeType; + } + + protected Iterator rowIterator() { + return rowIterator; + } + + /** + * Creates an iterator composing {@link RemoveCarryoverIterator} and {@link ComputeUpdateIterator} + * to remove carry-over rows and compute update rows + * + * @param rowIterator the iterator of rows from a changelog table + * @param rowType the schema of the rows + * @param identifierFields the names of the identifier columns, which determine if rows are the + * same + * @return a new iterator instance + */ + public static Iterator computeUpdates( + Iterator rowIterator, StructType rowType, String[] identifierFields) { + Iterator carryoverRemoveIterator = removeCarryovers(rowIterator, rowType); + ChangelogIterator changelogIterator = + new ComputeUpdateIterator(carryoverRemoveIterator, rowType, identifierFields); + return Iterators.filter(changelogIterator, Objects::nonNull); + } + + /** + * Creates an iterator that removes carry-over rows from a changelog table. 
+ * + * @param rowIterator the iterator of rows from a changelog table + * @param rowType the schema of the rows + * @return a new iterator instance + */ + public static Iterator removeCarryovers(Iterator rowIterator, StructType rowType) { + RemoveCarryoverIterator changelogIterator = new RemoveCarryoverIterator(rowIterator, rowType); + return Iterators.filter(changelogIterator, Objects::nonNull); + } + + public static Iterator removeNetCarryovers(Iterator rowIterator, StructType rowType) { + ChangelogIterator changelogIterator = new RemoveNetCarryoverIterator(rowIterator, rowType); + return Iterators.filter(changelogIterator, Objects::nonNull); + } + + protected boolean isSameRecord(Row currentRow, Row nextRow, int[] indicesToIdentifySameRow) { + for (int idx : indicesToIdentifySameRow) { + if (isDifferentValue(currentRow, nextRow, idx)) { + return false; + } + } + + return true; + } + + protected boolean isDifferentValue(Row currentRow, Row nextRow, int idx) { + return !Objects.equals(nextRow.get(idx), currentRow.get(idx)); + } + + protected static int[] generateIndicesToIdentifySameRow( + int totalColumnCount, Set metadataColumnIndices) { + int[] indices = new int[totalColumnCount - metadataColumnIndices.size()]; + + for (int i = 0, j = 0; i < indices.length; i++) { + if (!metadataColumnIndices.contains(i)) { + indices[j] = i; + j++; + } + } + return indices; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/CommitMetadata.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/CommitMetadata.java new file mode 100644 index 000000000000..ea400a779235 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/CommitMetadata.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Map; +import java.util.concurrent.Callable; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.util.ExceptionUtil; + +/** utility class to accept thread local commit properties */ +public class CommitMetadata { + + private CommitMetadata() {} + + private static final ThreadLocal> COMMIT_PROPERTIES = + ThreadLocal.withInitial(ImmutableMap::of); + + /** + * running the code wrapped as a caller, and any snapshot committed within the callable object + * will be attached with the metadata defined in properties + * + * @param properties extra commit metadata to attach to the snapshot committed within callable. 
+ * The prefix will be removed for properties starting with {@link + * SnapshotSummary#EXTRA_METADATA_PREFIX} + * @param callable the code to be executed + * @param exClass the expected type of exception which would be thrown from callable + */ + public static R withCommitProperties( + Map properties, Callable callable, Class exClass) throws E { + Map props = Maps.newHashMap(); + properties.forEach( + (k, v) -> props.put(k.replace(SnapshotSummary.EXTRA_METADATA_PREFIX, ""), v)); + + COMMIT_PROPERTIES.set(props); + try { + return callable.call(); + } catch (Throwable e) { + ExceptionUtil.castAndThrow(e, exClass); + return null; + } finally { + COMMIT_PROPERTIES.set(ImmutableMap.of()); + } + } + + public static Map commitProperties() { + return COMMIT_PROPERTIES.get(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/ComputeUpdateIterator.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/ComputeUpdateIterator.java new file mode 100644 index 000000000000..6951c33e51aa --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/ComputeUpdateIterator.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.StructType; + +/** + * An iterator that finds delete/insert rows which represent an update, and converts them into + * update records from changelog tables within a single Spark task. It assumes that rows are sorted + * by identifier columns and change type. + * + *

<p>For example, these two rows + * + * <ul> + *   <li>(id=1, data='a', op='DELETE') + *   <li>(id=1, data='b', op='INSERT') + * </ul> + * + * <p>will be marked as update-rows: + * + * <ul> + *   <li>(id=1, data='a', op='UPDATE_BEFORE') + *   <li>(id=1, data='b', op='UPDATE_AFTER') + * </ul>
+ */ +public class ComputeUpdateIterator extends ChangelogIterator { + + private final String[] identifierFields; + private final List identifierFieldIdx; + + private Row cachedRow = null; + + ComputeUpdateIterator(Iterator rowIterator, StructType rowType, String[] identifierFields) { + super(rowIterator, rowType); + this.identifierFieldIdx = + Arrays.stream(identifierFields).map(rowType::fieldIndex).collect(Collectors.toList()); + this.identifierFields = identifierFields; + } + + @Override + public boolean hasNext() { + if (cachedRow != null) { + return true; + } + return rowIterator().hasNext(); + } + + @Override + public Row next() { + // if there is an updated cached row, return it directly + if (cachedUpdateRecord()) { + Row row = cachedRow; + cachedRow = null; + return row; + } + + // either a cached record which is not an UPDATE or the next record in the iterator. + Row currentRow = currentRow(); + + if (changeType(currentRow).equals(DELETE) && rowIterator().hasNext()) { + Row nextRow = rowIterator().next(); + cachedRow = nextRow; + + if (sameLogicalRow(currentRow, nextRow)) { + Preconditions.checkState( + changeType(nextRow).equals(INSERT), + "Cannot compute updates because there are multiple rows with the same identifier" + + " fields([%s]). Please make sure the rows are unique.", + String.join(",", identifierFields)); + + currentRow = modify(currentRow, changeTypeIndex(), UPDATE_BEFORE); + cachedRow = modify(nextRow, changeTypeIndex(), UPDATE_AFTER); + } + } + + return currentRow; + } + + private Row modify(Row row, int valueIndex, Object value) { + if (row instanceof GenericRow) { + GenericRow genericRow = (GenericRow) row; + genericRow.values()[valueIndex] = value; + return genericRow; + } else { + Object[] values = new Object[row.size()]; + for (int index = 0; index < row.size(); index++) { + values[index] = row.get(index); + } + values[valueIndex] = value; + return RowFactory.create(values); + } + } + + private boolean cachedUpdateRecord() { + return cachedRow != null && changeType(cachedRow).equals(UPDATE_AFTER); + } + + private Row currentRow() { + if (cachedRow != null) { + Row row = cachedRow; + cachedRow = null; + return row; + } else { + return rowIterator().next(); + } + } + + private boolean sameLogicalRow(Row currentRow, Row nextRow) { + for (int idx : identifierFieldIdx) { + if (isDifferentValue(currentRow, nextRow, idx)) { + return false; + } + } + return true; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/ExtendedParser.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/ExtendedParser.java new file mode 100644 index 000000000000..19b3dd8f49be --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/ExtendedParser.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
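As an illustrative sketch of how the changelog iterators above can be applied (assumptions: the input DataFrame, the identifier column `id`, and a prior repartition/sort by the identifier columns and change type, which ComputeUpdateIterator requires):

    import org.apache.iceberg.spark.ChangelogIterator;
    import org.apache.spark.api.java.function.MapPartitionsFunction;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.types.StructType;

    class ComputeUpdatesSketch {
      // sortedChangelog is assumed to already be repartitioned and sorted by the
      // identifier column(s) and the change type column
      static Dataset<Row> markUpdates(Dataset<Row> sortedChangelog) {
        StructType schema = sortedChangelog.schema();
        String[] identifierFields = new String[] {"id"};
        return sortedChangelog.mapPartitions(
            (MapPartitionsFunction<Row, Row>)
                rows -> ChangelogIterator.computeUpdates(rows, schema, identifierFields),
            Encoders.row(schema));
      }
    }

The helper runs entirely inside each Spark task, so matching DELETE/INSERT pairs are only converted when they land in the same partition.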
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.SortDirection; +import org.apache.iceberg.expressions.Term; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.parser.ParserInterface; + +public interface ExtendedParser extends ParserInterface { + class RawOrderField { + private final Term term; + private final SortDirection direction; + private final NullOrder nullOrder; + + public RawOrderField(Term term, SortDirection direction, NullOrder nullOrder) { + this.term = term; + this.direction = direction; + this.nullOrder = nullOrder; + } + + public Term term() { + return term; + } + + public SortDirection direction() { + return direction; + } + + public NullOrder nullOrder() { + return nullOrder; + } + } + + static List parseSortOrder(SparkSession spark, String orderString) { + if (spark.sessionState().sqlParser() instanceof ExtendedParser) { + ExtendedParser parser = (ExtendedParser) spark.sessionState().sqlParser(); + try { + return parser.parseSortOrder(orderString); + } catch (AnalysisException e) { + throw new IllegalArgumentException( + String.format("Unable to parse sortOrder: %s", orderString), e); + } + } else { + throw new IllegalStateException( + "Cannot parse order: parser is not an Iceberg ExtendedParser"); + } + } + + List parseSortOrder(String orderString) throws AnalysisException; +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/FileRewriteCoordinator.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/FileRewriteCoordinator.java new file mode 100644 index 000000000000..432f7737d623 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/FileRewriteCoordinator.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
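A usage sketch for ExtendedParser.parseSortOrder; the session is assumed to be configured with the Iceberg SQL extensions (otherwise the helper throws IllegalStateException), and the order string is an arbitrary example:

    import java.util.List;
    import org.apache.iceberg.spark.ExtendedParser;
    import org.apache.iceberg.spark.ExtendedParser.RawOrderField;
    import org.apache.spark.sql.SparkSession;

    class ParseSortOrderSketch {
      static void printSortOrder(SparkSession spark) {
        List<RawOrderField> fields =
            ExtendedParser.parseSortOrder(spark, "bucket(16, id), ts DESC NULLS LAST");
        for (RawOrderField field : fields) {
          // each raw field carries the term, sort direction, and null ordering
          System.out.printf("%s %s %s%n", field.term(), field.direction(), field.nullOrder());
        }
      }
    }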
+ */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.DataFile; + +public class FileRewriteCoordinator extends BaseFileRewriteCoordinator { + + private static final FileRewriteCoordinator INSTANCE = new FileRewriteCoordinator(); + + private FileRewriteCoordinator() {} + + public static FileRewriteCoordinator get() { + return INSTANCE; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java new file mode 100644 index 000000000000..eb2420c0b254 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.function.Function; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Type; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; + +public class IcebergSpark { + private IcebergSpark() {} + + public static void registerBucketUDF( + SparkSession session, String funcName, DataType sourceType, int numBuckets) { + SparkTypeToType typeConverter = new SparkTypeToType(); + Type sourceIcebergType = typeConverter.atomic(sourceType); + Function bucket = Transforms.bucket(numBuckets).bind(sourceIcebergType); + session + .udf() + .register( + funcName, + value -> bucket.apply(SparkValueConverter.convert(sourceIcebergType, value)), + DataTypes.IntegerType); + } + + public static void registerTruncateUDF( + SparkSession session, String funcName, DataType sourceType, int width) { + SparkTypeToType typeConverter = new SparkTypeToType(); + Type sourceIcebergType = typeConverter.atomic(sourceType); + Function truncate = Transforms.truncate(width).bind(sourceIcebergType); + session + .udf() + .register( + funcName, + value -> truncate.apply(SparkValueConverter.convert(sourceIcebergType, value)), + sourceType); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/JobGroupInfo.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/JobGroupInfo.java new file mode 100644 index 000000000000..dc59fc70880e --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/JobGroupInfo.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
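A sketch of registering the bucket UDF that IcebergSpark exposes; the function name, the column, and the table `source_table` are assumptions:

    import org.apache.iceberg.spark.IcebergSpark;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.types.DataTypes;

    class BucketUdfSketch {
      static void register(SparkSession spark) {
        // exposes Iceberg's bucket(16) transform of a BIGINT value as a SQL function
        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket16", DataTypes.LongType, 16);
        spark.sql("SELECT id, iceberg_bucket16(id) AS bucket FROM source_table").show();
      }
    }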
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +/** Captures information about the current job which is used for displaying on the UI */ +public class JobGroupInfo { + private final String groupId; + private final String description; + private final boolean interruptOnCancel; + + public JobGroupInfo(String groupId, String desc) { + this(groupId, desc, false); + } + + public JobGroupInfo(String groupId, String desc, boolean interruptOnCancel) { + this.groupId = groupId; + this.description = desc; + this.interruptOnCancel = interruptOnCancel; + } + + public String groupId() { + return groupId; + } + + public String description() { + return description; + } + + public boolean interruptOnCancel() { + return interruptOnCancel; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/JobGroupUtils.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/JobGroupUtils.java new file mode 100644 index 000000000000..a6aadf7ebd0e --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/JobGroupUtils.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.util.function.Supplier; +import org.apache.spark.SparkContext; +import org.apache.spark.SparkContext$; +import org.apache.spark.api.java.JavaSparkContext; + +public class JobGroupUtils { + + private static final String JOB_GROUP_ID = SparkContext$.MODULE$.SPARK_JOB_GROUP_ID(); + private static final String JOB_GROUP_DESC = SparkContext$.MODULE$.SPARK_JOB_DESCRIPTION(); + private static final String JOB_INTERRUPT_ON_CANCEL = + SparkContext$.MODULE$.SPARK_JOB_INTERRUPT_ON_CANCEL(); + + private JobGroupUtils() {} + + public static JobGroupInfo getJobGroupInfo(SparkContext sparkContext) { + String groupId = sparkContext.getLocalProperty(JOB_GROUP_ID); + String description = sparkContext.getLocalProperty(JOB_GROUP_DESC); + String interruptOnCancel = sparkContext.getLocalProperty(JOB_INTERRUPT_ON_CANCEL); + return new JobGroupInfo(groupId, description, Boolean.parseBoolean(interruptOnCancel)); + } + + public static void setJobGroupInfo(SparkContext sparkContext, JobGroupInfo info) { + sparkContext.setLocalProperty(JOB_GROUP_ID, info.groupId()); + sparkContext.setLocalProperty(JOB_GROUP_DESC, info.description()); + sparkContext.setLocalProperty( + JOB_INTERRUPT_ON_CANCEL, String.valueOf(info.interruptOnCancel())); + } + + public static T withJobGroupInfo( + JavaSparkContext sparkContext, JobGroupInfo info, Supplier supplier) { + return withJobGroupInfo(sparkContext.sc(), info, supplier); + } + + public static T withJobGroupInfo( + SparkContext sparkContext, JobGroupInfo info, Supplier supplier) { + JobGroupInfo previousInfo = getJobGroupInfo(sparkContext); + try { + setJobGroupInfo(sparkContext, info); + return supplier.get(); + } finally { + setJobGroupInfo(sparkContext, previousInfo); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/PathIdentifier.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/PathIdentifier.java new file mode 100644 index 000000000000..110af6b87de5 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/PathIdentifier.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
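A sketch of wrapping a Spark action with JobGroupUtils.withJobGroupInfo so it appears under a descriptive job group in the UI; the group id, description, and table name are assumptions:

    import org.apache.iceberg.spark.JobGroupInfo;
    import org.apache.iceberg.spark.JobGroupUtils;
    import org.apache.spark.sql.SparkSession;

    class JobGroupSketch {
      static long countWithJobGroup(SparkSession spark, String table) {
        JobGroupInfo info = new JobGroupInfo("COUNT-" + table, "Counting rows of " + table, true);
        // the previous job group properties are restored once the supplier finishes
        return JobGroupUtils.withJobGroupInfo(
            spark.sparkContext(), info, () -> spark.table(table).count());
      }
    }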
+ */ +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.spark.sql.connector.catalog.Identifier; + +public class PathIdentifier implements Identifier { + private static final Splitter SPLIT = Splitter.on("/"); + private static final Joiner JOIN = Joiner.on("/"); + private final String[] namespace; + private final String location; + private final String name; + + public PathIdentifier(String location) { + this.location = location; + List pathParts = SPLIT.splitToList(location); + name = Iterables.getLast(pathParts); + namespace = + pathParts.size() > 1 + ? new String[] {JOIN.join(pathParts.subList(0, pathParts.size() - 1))} + : new String[0]; + } + + @Override + public String[] namespace() { + return namespace; + } + + @Override + public String name() { + return name; + } + + public String location() { + return location; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/PositionDeletesRewriteCoordinator.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/PositionDeletesRewriteCoordinator.java new file mode 100644 index 000000000000..c7568005e22f --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/PositionDeletesRewriteCoordinator.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.DeleteFile; + +public class PositionDeletesRewriteCoordinator extends BaseFileRewriteCoordinator { + + private static final PositionDeletesRewriteCoordinator INSTANCE = + new PositionDeletesRewriteCoordinator(); + + private PositionDeletesRewriteCoordinator() {} + + public static PositionDeletesRewriteCoordinator get() { + return INSTANCE; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithReordering.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithReordering.java new file mode 100644 index 000000000000..f76f12355f1f --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithReordering.java @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
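A small sketch showing how PathIdentifier splits a location into a namespace and a name; the path below is an arbitrary example:

    import org.apache.iceberg.spark.PathIdentifier;

    class PathIdentifierSketch {
      public static void main(String[] args) {
        PathIdentifier ident = new PathIdentifier("s3://bucket/warehouse/db/table");
        System.out.println(ident.name());         // table
        System.out.println(ident.namespace()[0]); // s3://bucket/warehouse/db
        System.out.println(ident.location());     // s3://bucket/warehouse/db/table
      }
    }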
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Type.TypeID; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BinaryType$; +import org.apache.spark.sql.types.BooleanType$; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType$; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType$; +import org.apache.spark.sql.types.FloatType$; +import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StringType$; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampNTZType$; +import org.apache.spark.sql.types.TimestampType$; + +public class PruneColumnsWithReordering extends TypeUtil.CustomOrderSchemaVisitor { + private final StructType requestedType; + private final Set filterRefs; + private DataType current = null; + + PruneColumnsWithReordering(StructType requestedType, Set filterRefs) { + this.requestedType = requestedType; + this.filterRefs = filterRefs; + } + + @Override + public Type schema(Schema schema, Supplier structResult) { + this.current = requestedType; + try { + return structResult.get(); + } finally { + this.current = null; + } + } + + @Override + public Type struct(Types.StructType struct, Iterable fieldResults) { + Preconditions.checkNotNull( + struct, "Cannot prune null struct. 
Pruning must start with a schema."); + Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current); + StructType requestedStruct = (StructType) current; + + List fields = struct.fields(); + List types = Lists.newArrayList(fieldResults); + + boolean changed = false; + // use a LinkedHashMap to preserve the original order of filter fields that are not projected + Map projectedFields = Maps.newLinkedHashMap(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + Type type = types.get(i); + + if (type == null) { + changed = true; + + } else if (field.type() == type) { + projectedFields.put(field.name(), field); + + } else if (field.isOptional()) { + changed = true; + projectedFields.put( + field.name(), Types.NestedField.optional(field.fieldId(), field.name(), type)); + + } else { + changed = true; + projectedFields.put( + field.name(), Types.NestedField.required(field.fieldId(), field.name(), type)); + } + } + + // Construct a new struct with the projected struct's order + boolean reordered = false; + StructField[] requestedFields = requestedStruct.fields(); + List newFields = Lists.newArrayListWithExpectedSize(requestedFields.length); + for (int i = 0; i < requestedFields.length; i += 1) { + // fields are resolved by name because Spark only sees the current table schema. + String name = requestedFields[i].name(); + if (!fields.get(i).name().equals(name)) { + reordered = true; + } + newFields.add(projectedFields.remove(name)); + } + + // Add remaining filter fields that were not explicitly projected + if (!projectedFields.isEmpty()) { + newFields.addAll(projectedFields.values()); + changed = true; // order probably changed + } + + if (reordered || changed) { + return Types.StructType.of(newFields); + } + + return struct; + } + + @Override + public Type field(Types.NestedField field, Supplier fieldResult) { + Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current); + StructType requestedStruct = (StructType) current; + + // fields are resolved by name because Spark only sees the current table schema. + if (requestedStruct.getFieldIndex(field.name()).isEmpty()) { + // make sure that filter fields are projected even if they aren't in the requested schema. 
+ if (filterRefs.contains(field.fieldId())) { + return field.type(); + } + return null; + } + + int fieldIndex = requestedStruct.fieldIndex(field.name()); + StructField requestedField = requestedStruct.fields()[fieldIndex]; + + Preconditions.checkArgument( + requestedField.nullable() || field.isRequired(), + "Cannot project an optional field as non-null: %s", + field.name()); + + this.current = requestedField.dataType(); + try { + return fieldResult.get(); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "Invalid projection for field " + field.name() + ": " + e.getMessage(), e); + } finally { + this.current = requestedStruct; + } + } + + @Override + public Type list(Types.ListType list, Supplier elementResult) { + Preconditions.checkArgument(current instanceof ArrayType, "Not an array: %s", current); + ArrayType requestedArray = (ArrayType) current; + + Preconditions.checkArgument( + requestedArray.containsNull() || !list.isElementOptional(), + "Cannot project an array of optional elements as required elements: %s", + requestedArray); + + this.current = requestedArray.elementType(); + try { + Type elementType = elementResult.get(); + if (list.elementType() == elementType) { + return list; + } + + // must be a projected element type, create a new list + if (list.isElementOptional()) { + return Types.ListType.ofOptional(list.elementId(), elementType); + } else { + return Types.ListType.ofRequired(list.elementId(), elementType); + } + } finally { + this.current = requestedArray; + } + } + + @Override + public Type map(Types.MapType map, Supplier keyResult, Supplier valueResult) { + Preconditions.checkArgument(current instanceof MapType, "Not a map: %s", current); + MapType requestedMap = (MapType) current; + + Preconditions.checkArgument( + requestedMap.valueContainsNull() || !map.isValueOptional(), + "Cannot project a map of optional values as required values: %s", + map); + Preconditions.checkArgument( + requestedMap.keyType() instanceof StringType, + "Invalid map key type (not string): %s", + requestedMap.keyType()); + + this.current = requestedMap.valueType(); + try { + Type valueType = valueResult.get(); + if (map.valueType() == valueType) { + return map; + } + + if (map.isValueOptional()) { + return Types.MapType.ofOptional(map.keyId(), map.valueId(), map.keyType(), valueType); + } else { + return Types.MapType.ofRequired(map.keyId(), map.valueId(), map.keyType(), valueType); + } + } finally { + this.current = requestedMap; + } + } + + @Override + public Type primitive(Type.PrimitiveType primitive) { + Set> expectedType = TYPES.get(primitive.typeId()); + Preconditions.checkArgument( + expectedType != null && expectedType.contains(current.getClass()), + "Cannot project %s to incompatible type: %s", + primitive, + current); + + // additional checks based on type + switch (primitive.typeId()) { + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + DecimalType requestedDecimal = (DecimalType) current; + Preconditions.checkArgument( + requestedDecimal.scale() == decimal.scale(), + "Cannot project decimal with incompatible scale: %s != %s", + requestedDecimal.scale(), + decimal.scale()); + Preconditions.checkArgument( + requestedDecimal.precision() >= decimal.precision(), + "Cannot project decimal with incompatible precision: %s < %s", + requestedDecimal.precision(), + decimal.precision()); + break; + case TIMESTAMP: + Types.TimestampType timestamp = (Types.TimestampType) primitive; + Preconditions.checkArgument( + 
timestamp.shouldAdjustToUTC(), + "Cannot project timestamp (without time zone) as timestamptz (with time zone)"); + break; + default: + } + + return primitive; + } + + private static final ImmutableMap>> TYPES = + ImmutableMap.>>builder() + .put(TypeID.BOOLEAN, ImmutableSet.of(BooleanType$.class)) + .put(TypeID.INTEGER, ImmutableSet.of(IntegerType$.class)) + .put(TypeID.LONG, ImmutableSet.of(LongType$.class)) + .put(TypeID.FLOAT, ImmutableSet.of(FloatType$.class)) + .put(TypeID.DOUBLE, ImmutableSet.of(DoubleType$.class)) + .put(TypeID.DATE, ImmutableSet.of(DateType$.class)) + .put(TypeID.TIMESTAMP, ImmutableSet.of(TimestampType$.class, TimestampNTZType$.class)) + .put(TypeID.DECIMAL, ImmutableSet.of(DecimalType.class)) + .put(TypeID.UUID, ImmutableSet.of(StringType$.class)) + .put(TypeID.STRING, ImmutableSet.of(StringType$.class)) + .put(TypeID.FIXED, ImmutableSet.of(BinaryType$.class)) + .put(TypeID.BINARY, ImmutableSet.of(BinaryType.class)) + .buildOrThrow(); +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithoutReordering.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithoutReordering.java new file mode 100644 index 000000000000..fbd21f737450 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithoutReordering.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
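A sketch of driving the pruning visitor above. It assumes TypeUtil.visit(Schema, CustomOrderSchemaVisitor) dispatches the visitor, as it does for the project's other schema visitors, and because the constructor is package-private the helper would have to live in org.apache.iceberg.spark:

    package org.apache.iceberg.spark;

    import java.util.Set;
    import org.apache.iceberg.Schema;
    import org.apache.iceberg.types.TypeUtil;
    import org.apache.spark.sql.types.StructType;

    class PruneSketch {
      // returns the table schema pruned and reordered to match the Spark-requested struct,
      // keeping any fields referenced by pushed-down filters
      static Schema prune(Schema tableSchema, StructType requestedType, Set<Integer> filterFieldIds) {
        return new Schema(
            TypeUtil.visit(tableSchema, new PruneColumnsWithReordering(requestedType, filterFieldIds))
                .asStructType()
                .fields());
      }
    }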
+ */ +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Set; +import java.util.function.Supplier; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Type.TypeID; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType$; +import org.apache.spark.sql.types.BooleanType$; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType$; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType$; +import org.apache.spark.sql.types.FloatType$; +import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StringType$; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampNTZType$; +import org.apache.spark.sql.types.TimestampType$; + +public class PruneColumnsWithoutReordering extends TypeUtil.CustomOrderSchemaVisitor { + private final StructType requestedType; + private final Set filterRefs; + private DataType current = null; + + PruneColumnsWithoutReordering(StructType requestedType, Set filterRefs) { + this.requestedType = requestedType; + this.filterRefs = filterRefs; + } + + @Override + public Type schema(Schema schema, Supplier structResult) { + this.current = requestedType; + try { + return structResult.get(); + } finally { + this.current = null; + } + } + + @Override + public Type struct(Types.StructType struct, Iterable fieldResults) { + Preconditions.checkNotNull( + struct, "Cannot prune null struct. Pruning must start with a schema."); + Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current); + + List fields = struct.fields(); + List types = Lists.newArrayList(fieldResults); + + boolean changed = false; + List newFields = Lists.newArrayListWithExpectedSize(types.size()); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + Type type = types.get(i); + + if (type == null) { + changed = true; + + } else if (field.type() == type) { + newFields.add(field); + + } else if (field.isOptional()) { + changed = true; + newFields.add(Types.NestedField.optional(field.fieldId(), field.name(), type)); + + } else { + changed = true; + newFields.add(Types.NestedField.required(field.fieldId(), field.name(), type)); + } + } + + if (changed) { + return Types.StructType.of(newFields); + } + + return struct; + } + + @Override + public Type field(Types.NestedField field, Supplier fieldResult) { + Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current); + StructType requestedStruct = (StructType) current; + + // fields are resolved by name because Spark only sees the current table schema. + if (requestedStruct.getFieldIndex(field.name()).isEmpty()) { + // make sure that filter fields are projected even if they aren't in the requested schema. 
+ if (filterRefs.contains(field.fieldId())) { + return field.type(); + } + return null; + } + + int fieldIndex = requestedStruct.fieldIndex(field.name()); + StructField requestedField = requestedStruct.fields()[fieldIndex]; + + Preconditions.checkArgument( + requestedField.nullable() || field.isRequired(), + "Cannot project an optional field as non-null: %s", + field.name()); + + this.current = requestedField.dataType(); + try { + return fieldResult.get(); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "Invalid projection for field " + field.name() + ": " + e.getMessage(), e); + } finally { + this.current = requestedStruct; + } + } + + @Override + public Type list(Types.ListType list, Supplier elementResult) { + Preconditions.checkArgument(current instanceof ArrayType, "Not an array: %s", current); + ArrayType requestedArray = (ArrayType) current; + + Preconditions.checkArgument( + requestedArray.containsNull() || !list.isElementOptional(), + "Cannot project an array of optional elements as required elements: %s", + requestedArray); + + this.current = requestedArray.elementType(); + try { + Type elementType = elementResult.get(); + if (list.elementType() == elementType) { + return list; + } + + // must be a projected element type, create a new list + if (list.isElementOptional()) { + return Types.ListType.ofOptional(list.elementId(), elementType); + } else { + return Types.ListType.ofRequired(list.elementId(), elementType); + } + } finally { + this.current = requestedArray; + } + } + + @Override + public Type map(Types.MapType map, Supplier keyResult, Supplier valueResult) { + Preconditions.checkArgument(current instanceof MapType, "Not a map: %s", current); + MapType requestedMap = (MapType) current; + + Preconditions.checkArgument( + requestedMap.valueContainsNull() || !map.isValueOptional(), + "Cannot project a map of optional values as required values: %s", + map); + + this.current = requestedMap.valueType(); + try { + Type valueType = valueResult.get(); + if (map.valueType() == valueType) { + return map; + } + + if (map.isValueOptional()) { + return Types.MapType.ofOptional(map.keyId(), map.valueId(), map.keyType(), valueType); + } else { + return Types.MapType.ofRequired(map.keyId(), map.valueId(), map.keyType(), valueType); + } + } finally { + this.current = requestedMap; + } + } + + @Override + public Type primitive(Type.PrimitiveType primitive) { + Set> expectedType = TYPES.get(primitive.typeId()); + Preconditions.checkArgument( + expectedType != null && expectedType.contains(current.getClass()), + "Cannot project %s to incompatible type: %s", + primitive, + current); + + // additional checks based on type + switch (primitive.typeId()) { + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + DecimalType requestedDecimal = (DecimalType) current; + Preconditions.checkArgument( + requestedDecimal.scale() == decimal.scale(), + "Cannot project decimal with incompatible scale: %s != %s", + requestedDecimal.scale(), + decimal.scale()); + Preconditions.checkArgument( + requestedDecimal.precision() >= decimal.precision(), + "Cannot project decimal with incompatible precision: %s < %s", + requestedDecimal.precision(), + decimal.precision()); + break; + default: + } + + return primitive; + } + + private static final ImmutableMap>> TYPES = + ImmutableMap.>>builder() + .put(TypeID.BOOLEAN, ImmutableSet.of(BooleanType$.class)) + .put(TypeID.INTEGER, ImmutableSet.of(IntegerType$.class)) + .put(TypeID.LONG, 
ImmutableSet.of(LongType$.class)) + .put(TypeID.FLOAT, ImmutableSet.of(FloatType$.class)) + .put(TypeID.DOUBLE, ImmutableSet.of(DoubleType$.class)) + .put(TypeID.DATE, ImmutableSet.of(DateType$.class)) + .put(TypeID.TIMESTAMP, ImmutableSet.of(TimestampType$.class, TimestampNTZType$.class)) + .put(TypeID.DECIMAL, ImmutableSet.of(DecimalType.class)) + .put(TypeID.UUID, ImmutableSet.of(StringType$.class)) + .put(TypeID.STRING, ImmutableSet.of(StringType$.class)) + .put(TypeID.FIXED, ImmutableSet.of(BinaryType$.class)) + .put(TypeID.BINARY, ImmutableSet.of(BinaryType$.class)) + .buildOrThrow(); +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/RemoveCarryoverIterator.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/RemoveCarryoverIterator.java new file mode 100644 index 000000000000..2e90dc7749d1 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/RemoveCarryoverIterator.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Iterator; +import java.util.Set; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructType; + +/** + * An iterator that removes the carry-over rows from changelog tables within a single Spark task. It + * assumes that rows are partitioned by identifier(or all) columns, and it is sorted by both + * identifier(or all) columns and change type. + * + *

<p>Carry-over rows are the result of a removal and insertion of the same row within an operation + * because of the copy-on-write mechanism. For example, given a file that contains row1 (id=1, + * data='a') and row2 (id=2, data='b'), a copy-on-write delete of row2 would require erasing this + * file and preserving row1 in a new file. The changelog table would report this as follows, + * despite it not being an actual change to the table. + * + * <ul> + *   <li>(id=1, data='a', op='DELETE') + *   <li>(id=1, data='a', op='INSERT') + *   <li>(id=2, data='b', op='DELETE') + * </ul> + * + * <p>The iterator finds the carry-over rows and removes them from the result. For example, the + * above rows will be converted to: + * + * <ul> + *   <li>(id=2, data='b', op='DELETE') + * </ul>
+ */ +class RemoveCarryoverIterator extends ChangelogIterator { + private final int[] indicesToIdentifySameRow; + + private Row cachedDeletedRow = null; + private long deletedRowCount = 0; + private Row cachedNextRecord = null; + + RemoveCarryoverIterator(Iterator rowIterator, StructType rowType) { + super(rowIterator, rowType); + this.indicesToIdentifySameRow = generateIndicesToIdentifySameRow(); + } + + @Override + public boolean hasNext() { + if (hasCachedDeleteRow() || cachedNextRecord != null) { + return true; + } + return rowIterator().hasNext(); + } + + @Override + public Row next() { + Row currentRow; + + if (returnCachedDeleteRow()) { + // Non-carryover delete rows found. One or more identical delete rows were seen followed by a + // non-identical row. This means none of the delete rows were carry over rows. Emit one + // delete row and decrease the amount of delete rows seen. + deletedRowCount--; + currentRow = cachedDeletedRow; + if (deletedRowCount == 0) { + cachedDeletedRow = null; + } + return currentRow; + } else if (cachedNextRecord != null) { + currentRow = cachedNextRecord; + cachedNextRecord = null; + } else { + currentRow = rowIterator().next(); + } + + // If the current row is a delete row, drain all identical delete rows + if (changeType(currentRow).equals(DELETE) && rowIterator().hasNext()) { + cachedDeletedRow = currentRow; + deletedRowCount = 1; + + Row nextRow = rowIterator().next(); + + // drain all identical delete rows when there is at least one cached delete row and the next + // row is the same record + while (nextRow != null + && cachedDeletedRow != null + && isSameRecord(cachedDeletedRow, nextRow, indicesToIdentifySameRow)) { + if (changeType(nextRow).equals(INSERT)) { + deletedRowCount--; + if (deletedRowCount == 0) { + cachedDeletedRow = null; + } + } else { + deletedRowCount++; + } + + if (rowIterator().hasNext()) { + nextRow = rowIterator().next(); + } else { + nextRow = null; + } + } + + cachedNextRecord = nextRow; + return null; + } else { + // either there is no cached delete row or the current row is not a delete row + return currentRow; + } + } + + /** + * The iterator returns a cached delete row if there are delete rows cached and the next row is + * not the same record or there is no next row. + */ + private boolean returnCachedDeleteRow() { + return hitBoundary() && hasCachedDeleteRow(); + } + + private boolean hitBoundary() { + return !rowIterator().hasNext() || cachedNextRecord != null; + } + + private boolean hasCachedDeleteRow() { + return cachedDeletedRow != null; + } + + private int[] generateIndicesToIdentifySameRow() { + Set metadataColumnIndices = Sets.newHashSet(changeTypeIndex()); + return generateIndicesToIdentifySameRow(rowType().size(), metadataColumnIndices); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/RemoveNetCarryoverIterator.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/RemoveNetCarryoverIterator.java new file mode 100644 index 000000000000..941e4a4731e2 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/RemoveNetCarryoverIterator.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Iterator; +import java.util.Set; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructType; + +/** + * This class computes the net changes across multiple snapshots. It is different from {@link + * org.apache.iceberg.spark.RemoveCarryoverIterator}, which only removes carry-over rows within a + * single snapshot. It takes a row iterator, and assumes the following: + * + *
    + *
  • The row iterator is partitioned by all columns. + *
  • The row iterator is sorted by all columns, change order, and change type. The change order + * maps 1-to-1 to the snapshot id. + *
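To make the cancellation rule above concrete, here is a minimal standalone sketch (illustrative only, not taken from the patch) that applies the same idea to plain (key, change type) pairs instead of Spark rows: a DELETE and a later INSERT of the same record cancel out, while unmatched changes survive.

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;

class NetCarryoverSketch {
  // A row reduced to the data values used for equality ("key") plus its change type.
  record Change(String key, String type) {}

  // Assumes the input is already sorted by key, change ordinal, and change type, as the
  // iterator above requires; consecutive opposite changes to the same record cancel out.
  static List<Change> removeNetCarryovers(List<Change> sorted) {
    Deque<Change> result = new ArrayDeque<>();
    for (Change change : sorted) {
      Change previous = result.peekLast();
      if (previous != null
          && previous.key().equals(change.key())
          && !previous.type().equals(change.type())) {
        result.pollLast(); // DELETE followed by INSERT of the same record: no net change
      } else {
        result.addLast(change);
      }
    }
    return List.copyOf(result);
  }

  public static void main(String[] args) {
    // "a" is deleted in one snapshot and re-inserted unchanged in a later one: both rows drop.
    // "b" is only inserted, so it is kept.
    System.out.println(
        removeNetCarryovers(
            List.of(
                new Change("a", "DELETE"),
                new Change("a", "INSERT"),
                new Change("b", "INSERT"))));
  }
}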
+ */ +public class RemoveNetCarryoverIterator extends ChangelogIterator { + + private final int[] indicesToIdentifySameRow; + + private Row cachedNextRow; + private Row cachedRow; + private long cachedRowCount; + + protected RemoveNetCarryoverIterator(Iterator rowIterator, StructType rowType) { + super(rowIterator, rowType); + this.indicesToIdentifySameRow = generateIndicesToIdentifySameRow(); + } + + @Override + public boolean hasNext() { + if (cachedRowCount > 0) { + return true; + } + + if (cachedNextRow != null) { + return true; + } + + return rowIterator().hasNext(); + } + + @Override + public Row next() { + // if there are cached rows, return one of them from the beginning + if (cachedRowCount > 0) { + cachedRowCount--; + return cachedRow; + } + + cachedRow = getCurrentRow(); + // return it directly if there is no more rows + if (!rowIterator().hasNext()) { + return cachedRow; + } + cachedRowCount = 1; + + cachedNextRow = rowIterator().next(); + + // pull rows from the iterator until two consecutive rows are different + while (isSameRecord(cachedRow, cachedNextRow, indicesToIdentifySameRow)) { + if (oppositeChangeType(cachedRow, cachedNextRow)) { + // two rows with opposite change types means no net changes, remove both + cachedRowCount--; + } else { + // two rows with same change types means potential net changes, cache the next row + cachedRowCount++; + } + + // stop pulling rows if there is no more rows or the next row is different + if (cachedRowCount <= 0 || !rowIterator().hasNext()) { + // reset the cached next row if there is no more rows + cachedNextRow = null; + break; + } + + cachedNextRow = rowIterator().next(); + } + + return null; + } + + private Row getCurrentRow() { + Row currentRow; + if (cachedNextRow != null) { + currentRow = cachedNextRow; + cachedNextRow = null; + } else { + currentRow = rowIterator().next(); + } + return currentRow; + } + + private boolean oppositeChangeType(Row currentRow, Row nextRow) { + return (changeType(nextRow).equals(INSERT) && changeType(currentRow).equals(DELETE)) + || (changeType(nextRow).equals(DELETE) && changeType(currentRow).equals(INSERT)); + } + + private int[] generateIndicesToIdentifySameRow() { + Set metadataColumnIndices = + Sets.newHashSet( + rowType().fieldIndex(MetadataColumns.CHANGE_ORDINAL.name()), + rowType().fieldIndex(MetadataColumns.COMMIT_SNAPSHOT_ID.name()), + changeTypeIndex()); + return generateIndicesToIdentifySameRow(rowType().size(), metadataColumnIndices); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/RollbackStagedTable.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/RollbackStagedTable.java new file mode 100644 index 000000000000..bc8a966488ee --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/RollbackStagedTable.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Function; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagedTable; +import org.apache.spark.sql.connector.catalog.SupportsDelete; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.SupportsWrite; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * An implementation of StagedTable that mimics the behavior of Spark's non-atomic CTAS and RTAS. + * + *

A Spark catalog can implement StagingTableCatalog to support atomic operations by producing + * StagedTable. But if a catalog implements StagingTableCatalog, Spark expects the catalog to be + * able to produce a StagedTable for any table loaded by the catalog. This assumption doesn't always + * work, as in the case of {@link SparkSessionCatalog}, which supports atomic operations and can produce + * a StagedTable for Iceberg tables, but wraps the session catalog and cannot necessarily produce a + * working StagedTable implementation for tables that it loads. + * + *

The work-around is this class, which implements the StagedTable interface but does not have + * atomic behavior. Instead, the StagedTable interface is used to implement the behavior of the + * non-atomic SQL plans that create a table, write to it, and drop the table to roll back. + * + *

This StagedTable implements SupportsRead, SupportsWrite, and SupportsDelete by passing the + * calls to the real table. Implementing those interfaces is safe because Spark will not use them + * unless the table supports them and returns the corresponding capabilities from {@link + * #capabilities()}. + */ +public class RollbackStagedTable + implements StagedTable, SupportsRead, SupportsWrite, SupportsDelete { + private final TableCatalog catalog; + private final Identifier ident; + private final Table table; + + public RollbackStagedTable(TableCatalog catalog, Identifier ident, Table table) { + this.catalog = catalog; + this.ident = ident; + this.table = table; + } + + @Override + public void commitStagedChanges() { + // the changes have already been committed to the table at the end of the write + } + + @Override + public void abortStagedChanges() { + // roll back changes by dropping the table + catalog.dropTable(ident); + } + + @Override + public String name() { + return table.name(); + } + + @Override + public StructType schema() { + return table.schema(); + } + + @Override + public Transform[] partitioning() { + return table.partitioning(); + } + + @Override + public Map properties() { + return table.properties(); + } + + @Override + public Set capabilities() { + return table.capabilities(); + } + + @Override + public void deleteWhere(Filter[] filters) { + call(SupportsDelete.class, t -> t.deleteWhere(filters)); + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + return callReturning(SupportsRead.class, t -> t.newScanBuilder(options)); + } + + @Override + public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { + return callReturning(SupportsWrite.class, t -> t.newWriteBuilder(info)); + } + + private void call(Class requiredClass, Consumer task) { + callReturning( + requiredClass, + inst -> { + task.accept(inst); + return null; + }); + } + + private R callReturning(Class requiredClass, Function task) { + if (requiredClass.isInstance(table)) { + return task.apply(requiredClass.cast(table)); + } else { + throw new UnsupportedOperationException( + String.format( + "Table does not implement %s: %s (%s)", + requiredClass.getSimpleName(), table.name(), table.getClass().getName())); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/ScanTaskSetManager.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/ScanTaskSetManager.java new file mode 100644 index 000000000000..cab40d103171 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/ScanTaskSetManager.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.util.Pair; + +public class ScanTaskSetManager { + + private static final ScanTaskSetManager INSTANCE = new ScanTaskSetManager(); + + private final Map, List> tasksMap = + Maps.newConcurrentMap(); + + private ScanTaskSetManager() {} + + public static ScanTaskSetManager get() { + return INSTANCE; + } + + public void stageTasks(Table table, String setId, List tasks) { + Preconditions.checkArgument( + tasks != null && !tasks.isEmpty(), "Cannot stage null or empty tasks"); + Pair id = toId(table, setId); + tasksMap.put(id, tasks); + } + + @SuppressWarnings("unchecked") + public List fetchTasks(Table table, String setId) { + Pair id = toId(table, setId); + return (List) tasksMap.get(id); + } + + @SuppressWarnings("unchecked") + public List removeTasks(Table table, String setId) { + Pair id = toId(table, setId); + return (List) tasksMap.remove(id); + } + + public Set fetchSetIds(Table table) { + return tasksMap.keySet().stream() + .filter(e -> e.first().equals(Spark3Util.baseTableUUID(table))) + .map(Pair::second) + .collect(Collectors.toSet()); + } + + private Pair toId(Table table, String setId) { + return Pair.of(Spark3Util.baseTableUUID(table), setId); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java new file mode 100644 index 000000000000..781f61b33f0e --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
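A possible usage sketch for the task-set manager above (hypothetical helper names; the table and tasks are assumed to come from an existing Iceberg scan). Tasks are keyed by the table UUID plus a caller-chosen set id, so the same id can later be used to fetch and clean up the staged tasks.

import java.util.List;
import java.util.UUID;
import org.apache.iceberg.FileScanTask;
import org.apache.iceberg.Table;
import org.apache.iceberg.spark.ScanTaskSetManager;

class ScanTaskStagingSketch {
  // Stages the given tasks under a fresh set id and returns that id.
  static String stage(Table table, List<FileScanTask> tasks) {
    String setId = UUID.randomUUID().toString();
    ScanTaskSetManager.get().stageTasks(table, setId, tasks);
    return setId;
  }

  // Fetches the staged tasks and removes the entry so the shared map does not grow unbounded.
  static List<FileScanTask> fetchAndClear(Table table, String setId) {
    List<FileScanTask> staged = ScanTaskSetManager.get().fetchTasks(table, setId);
    ScanTaskSetManager.get().removeTasks(table, setId);
    return staged;
  }
}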
+ */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortDirection; +import org.apache.iceberg.transforms.SortOrderVisitor; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NullOrdering; +import org.apache.spark.sql.connector.expressions.SortOrder; + +class SortOrderToSpark implements SortOrderVisitor { + + private final Map quotedNameById; + + SortOrderToSpark(Schema schema) { + this.quotedNameById = SparkSchemaUtil.indexQuotedNameById(schema); + } + + @Override + public SortOrder field(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.column(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder bucket( + String sourceName, int id, int width, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.bucket(width, quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder truncate( + String sourceName, int id, int width, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.apply( + "truncate", Expressions.literal(width), Expressions.column(quotedName(id))), + toSpark(direction), + toSpark(nullOrder)); + } + + @Override + public SortOrder year(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.years(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder month(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.months(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder day(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.days(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder hour(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.hours(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + private String quotedName(int id) { + return quotedNameById.get(id); + } + + private org.apache.spark.sql.connector.expressions.SortDirection toSpark( + SortDirection direction) { + if (direction == SortDirection.ASC) { + return org.apache.spark.sql.connector.expressions.SortDirection.ASCENDING; + } else { + return org.apache.spark.sql.connector.expressions.SortDirection.DESCENDING; + } + } + + private NullOrdering toSpark(NullOrder nullOrder) { + return nullOrder == NullOrder.NULLS_FIRST ? NullOrdering.NULLS_FIRST : NullOrdering.NULLS_LAST; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java new file mode 100644 index 000000000000..af0fa84f67a1 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java @@ -0,0 +1,1044 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.BaseMetadataTable; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.UpdateProperties; +import org.apache.iceberg.UpdateSchema; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.expressions.BoundPredicate; +import org.apache.iceberg.expressions.ExpressionVisitors; +import org.apache.iceberg.expressions.Term; +import org.apache.iceberg.expressions.UnboundPredicate; +import org.apache.iceberg.expressions.UnboundTerm; +import org.apache.iceberg.expressions.UnboundTransform; +import org.apache.iceberg.expressions.Zorder; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.io.BaseEncoding; +import org.apache.iceberg.spark.SparkTableUtil.SparkPartition; +import org.apache.iceberg.spark.source.HasIcebergCatalog; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.transforms.PartitionSpecVisitor; +import org.apache.iceberg.transforms.SortOrderVisitor; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.CatalystTypeConverters; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.catalyst.parser.ParserInterface; +import org.apache.spark.sql.connector.catalog.CatalogManager; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.Table; +import 
org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.execution.datasources.FileStatusCache; +import org.apache.spark.sql.execution.datasources.FileStatusWithMetadata; +import org.apache.spark.sql.execution.datasources.InMemoryFileIndex; +import org.apache.spark.sql.execution.datasources.PartitionDirectory; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Option; +import scala.collection.JavaConverters; +import scala.collection.immutable.Seq; + +public class Spark3Util { + + private static final Set RESERVED_PROPERTIES = + ImmutableSet.of(TableCatalog.PROP_LOCATION, TableCatalog.PROP_PROVIDER); + private static final Joiner DOT = Joiner.on("."); + + private Spark3Util() {} + + public static CaseInsensitiveStringMap setOption( + String key, String value, CaseInsensitiveStringMap options) { + Map newOptions = Maps.newHashMap(); + newOptions.putAll(options); + newOptions.put(key, value); + return new CaseInsensitiveStringMap(newOptions); + } + + public static Map rebuildCreateProperties(Map createProperties) { + ImmutableMap.Builder tableProperties = ImmutableMap.builder(); + createProperties.entrySet().stream() + .filter(entry -> !RESERVED_PROPERTIES.contains(entry.getKey())) + .forEach(tableProperties::put); + + String provider = createProperties.get(TableCatalog.PROP_PROVIDER); + if ("parquet".equalsIgnoreCase(provider)) { + tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "parquet"); + } else if ("avro".equalsIgnoreCase(provider)) { + tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "avro"); + } else if ("orc".equalsIgnoreCase(provider)) { + tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "orc"); + } else if (provider != null && !"iceberg".equalsIgnoreCase(provider)) { + throw new IllegalArgumentException("Unsupported format in USING: " + provider); + } + + return tableProperties.build(); + } + + /** + * Applies a list of Spark table changes to an {@link UpdateProperties} operation. + * + * @param pendingUpdate an uncommitted UpdateProperties operation to configure + * @param changes a list of Spark table changes + * @return the UpdateProperties operation configured with the changes + */ + public static UpdateProperties applyPropertyChanges( + UpdateProperties pendingUpdate, List changes) { + for (TableChange change : changes) { + if (change instanceof TableChange.SetProperty) { + TableChange.SetProperty set = (TableChange.SetProperty) change; + pendingUpdate.set(set.property(), set.value()); + + } else if (change instanceof TableChange.RemoveProperty) { + TableChange.RemoveProperty remove = (TableChange.RemoveProperty) change; + pendingUpdate.remove(remove.property()); + + } else { + throw new UnsupportedOperationException("Cannot apply unknown table change: " + change); + } + } + + return pendingUpdate; + } + + /** + * Applies a list of Spark table changes to an {@link UpdateSchema} operation. 
+ * + * @param pendingUpdate an uncommitted UpdateSchema operation to configure + * @param changes a list of Spark table changes + * @return the UpdateSchema operation configured with the changes + */ + public static UpdateSchema applySchemaChanges( + UpdateSchema pendingUpdate, List changes) { + for (TableChange change : changes) { + if (change instanceof TableChange.AddColumn) { + apply(pendingUpdate, (TableChange.AddColumn) change); + + } else if (change instanceof TableChange.UpdateColumnType) { + TableChange.UpdateColumnType update = (TableChange.UpdateColumnType) change; + Type newType = SparkSchemaUtil.convert(update.newDataType()); + Preconditions.checkArgument( + newType.isPrimitiveType(), + "Cannot update '%s', not a primitive type: %s", + DOT.join(update.fieldNames()), + update.newDataType()); + pendingUpdate.updateColumn(DOT.join(update.fieldNames()), newType.asPrimitiveType()); + + } else if (change instanceof TableChange.UpdateColumnComment) { + TableChange.UpdateColumnComment update = (TableChange.UpdateColumnComment) change; + pendingUpdate.updateColumnDoc(DOT.join(update.fieldNames()), update.newComment()); + + } else if (change instanceof TableChange.RenameColumn) { + TableChange.RenameColumn rename = (TableChange.RenameColumn) change; + pendingUpdate.renameColumn(DOT.join(rename.fieldNames()), rename.newName()); + + } else if (change instanceof TableChange.DeleteColumn) { + TableChange.DeleteColumn delete = (TableChange.DeleteColumn) change; + pendingUpdate.deleteColumn(DOT.join(delete.fieldNames())); + + } else if (change instanceof TableChange.UpdateColumnNullability) { + TableChange.UpdateColumnNullability update = (TableChange.UpdateColumnNullability) change; + if (update.nullable()) { + pendingUpdate.makeColumnOptional(DOT.join(update.fieldNames())); + } else { + pendingUpdate.requireColumn(DOT.join(update.fieldNames())); + } + + } else if (change instanceof TableChange.UpdateColumnPosition) { + apply(pendingUpdate, (TableChange.UpdateColumnPosition) change); + + } else { + throw new UnsupportedOperationException("Cannot apply unknown table change: " + change); + } + } + + return pendingUpdate; + } + + private static void apply(UpdateSchema pendingUpdate, TableChange.UpdateColumnPosition update) { + Preconditions.checkArgument(update.position() != null, "Invalid position: null"); + + if (update.position() instanceof TableChange.After) { + TableChange.After after = (TableChange.After) update.position(); + String referenceField = peerName(update.fieldNames(), after.column()); + pendingUpdate.moveAfter(DOT.join(update.fieldNames()), referenceField); + + } else if (update.position() instanceof TableChange.First) { + pendingUpdate.moveFirst(DOT.join(update.fieldNames())); + + } else { + throw new IllegalArgumentException("Unknown position for reorder: " + update.position()); + } + } + + private static void apply(UpdateSchema pendingUpdate, TableChange.AddColumn add) { + Preconditions.checkArgument( + add.isNullable(), + "Incompatible change: cannot add required column: %s", + leafName(add.fieldNames())); + Type type = SparkSchemaUtil.convert(add.dataType()); + pendingUpdate.addColumn( + parentName(add.fieldNames()), leafName(add.fieldNames()), type, add.comment()); + + if (add.position() instanceof TableChange.After) { + TableChange.After after = (TableChange.After) add.position(); + String referenceField = peerName(add.fieldNames(), after.column()); + pendingUpdate.moveAfter(DOT.join(add.fieldNames()), referenceField); + + } else if (add.position() instanceof 
TableChange.First) { + pendingUpdate.moveFirst(DOT.join(add.fieldNames())); + + } else { + Preconditions.checkArgument( + add.position() == null, + "Cannot add '%s' at unknown position: %s", + DOT.join(add.fieldNames()), + add.position()); + } + } + + public static org.apache.iceberg.Table toIcebergTable(Table table) { + Preconditions.checkArgument( + table instanceof SparkTable, "Table %s is not an Iceberg table", table); + SparkTable sparkTable = (SparkTable) table; + return sparkTable.table(); + } + + public static SortOrder[] toOrdering(org.apache.iceberg.SortOrder sortOrder) { + SortOrderToSpark visitor = new SortOrderToSpark(sortOrder.schema()); + List ordering = SortOrderVisitor.visit(sortOrder, visitor); + return ordering.toArray(new SortOrder[0]); + } + + public static Transform[] toTransforms(Schema schema, List fields) { + SpecTransformToSparkTransform visitor = new SpecTransformToSparkTransform(schema); + + List transforms = Lists.newArrayList(); + + for (PartitionField field : fields) { + Transform transform = PartitionSpecVisitor.visit(schema, field, visitor); + if (transform != null) { + transforms.add(transform); + } + } + + return transforms.toArray(new Transform[0]); + } + + /** + * Converts a PartitionSpec to Spark transforms. + * + * @param spec a PartitionSpec + * @return an array of Transforms + */ + public static Transform[] toTransforms(PartitionSpec spec) { + SpecTransformToSparkTransform visitor = new SpecTransformToSparkTransform(spec.schema()); + List transforms = PartitionSpecVisitor.visit(spec, visitor); + return transforms.stream().filter(Objects::nonNull).toArray(Transform[]::new); + } + + private static class SpecTransformToSparkTransform implements PartitionSpecVisitor { + private final Map quotedNameById; + + SpecTransformToSparkTransform(Schema schema) { + this.quotedNameById = SparkSchemaUtil.indexQuotedNameById(schema); + } + + @Override + public Transform identity(String sourceName, int sourceId) { + return Expressions.identity(quotedName(sourceId)); + } + + @Override + public Transform bucket(String sourceName, int sourceId, int numBuckets) { + return Expressions.bucket(numBuckets, quotedName(sourceId)); + } + + @Override + public Transform truncate(String sourceName, int sourceId, int width) { + NamedReference column = Expressions.column(quotedName(sourceId)); + return Expressions.apply("truncate", Expressions.literal(width), column); + } + + @Override + public Transform year(String sourceName, int sourceId) { + return Expressions.years(quotedName(sourceId)); + } + + @Override + public Transform month(String sourceName, int sourceId) { + return Expressions.months(quotedName(sourceId)); + } + + @Override + public Transform day(String sourceName, int sourceId) { + return Expressions.days(quotedName(sourceId)); + } + + @Override + public Transform hour(String sourceName, int sourceId) { + return Expressions.hours(quotedName(sourceId)); + } + + @Override + public Transform alwaysNull(int fieldId, String sourceName, int sourceId) { + // do nothing for alwaysNull, it doesn't need to be converted to a transform + return null; + } + + @Override + public Transform unknown(int fieldId, String sourceName, int sourceId, String transform) { + return Expressions.apply(transform, Expressions.column(quotedName(sourceId))); + } + + private String quotedName(int id) { + return quotedNameById.get(id); + } + } + + public static NamedReference toNamedReference(String name) { + return Expressions.column(name); + } + + public static Term toIcebergTerm(Expression expr) { + 
if (expr instanceof Transform) { + Transform transform = (Transform) expr; + Preconditions.checkArgument( + "zorder".equals(transform.name()) || transform.references().length == 1, + "Cannot convert transform with more than one column reference: %s", + transform); + String colName = DOT.join(transform.references()[0].fieldNames()); + switch (transform.name().toLowerCase(Locale.ROOT)) { + case "identity": + return org.apache.iceberg.expressions.Expressions.ref(colName); + case "bucket": + return org.apache.iceberg.expressions.Expressions.bucket(colName, findWidth(transform)); + case "year": + case "years": + return org.apache.iceberg.expressions.Expressions.year(colName); + case "month": + case "months": + return org.apache.iceberg.expressions.Expressions.month(colName); + case "date": + case "day": + case "days": + return org.apache.iceberg.expressions.Expressions.day(colName); + case "date_hour": + case "hour": + case "hours": + return org.apache.iceberg.expressions.Expressions.hour(colName); + case "truncate": + return org.apache.iceberg.expressions.Expressions.truncate(colName, findWidth(transform)); + case "zorder": + return new Zorder( + Stream.of(transform.references()) + .map(ref -> DOT.join(ref.fieldNames())) + .map(org.apache.iceberg.expressions.Expressions::ref) + .collect(Collectors.toList())); + default: + throw new UnsupportedOperationException("Transform is not supported: " + transform); + } + + } else if (expr instanceof NamedReference) { + NamedReference ref = (NamedReference) expr; + return org.apache.iceberg.expressions.Expressions.ref(DOT.join(ref.fieldNames())); + + } else { + throw new UnsupportedOperationException("Cannot convert unknown expression: " + expr); + } + } + + /** + * Converts Spark transforms into a {@link PartitionSpec}. 
+ * + * @param schema the table schema + * @param partitioning Spark Transforms + * @return a PartitionSpec + */ + public static PartitionSpec toPartitionSpec(Schema schema, Transform[] partitioning) { + if (partitioning == null || partitioning.length == 0) { + return PartitionSpec.unpartitioned(); + } + + PartitionSpec.Builder builder = PartitionSpec.builderFor(schema); + for (Transform transform : partitioning) { + Preconditions.checkArgument( + transform.references().length == 1, + "Cannot convert transform with more than one column reference: %s", + transform); + String colName = DOT.join(transform.references()[0].fieldNames()); + switch (transform.name().toLowerCase(Locale.ROOT)) { + case "identity": + builder.identity(colName); + break; + case "bucket": + builder.bucket(colName, findWidth(transform)); + break; + case "year": + case "years": + builder.year(colName); + break; + case "month": + case "months": + builder.month(colName); + break; + case "date": + case "day": + case "days": + builder.day(colName); + break; + case "date_hour": + case "hour": + case "hours": + builder.hour(colName); + break; + case "truncate": + builder.truncate(colName, findWidth(transform)); + break; + default: + throw new UnsupportedOperationException("Transform is not supported: " + transform); + } + } + + return builder.build(); + } + + @SuppressWarnings("unchecked") + private static int findWidth(Transform transform) { + for (Expression expr : transform.arguments()) { + if (expr instanceof Literal) { + if (((Literal) expr).dataType() instanceof IntegerType) { + Literal lit = (Literal) expr; + Preconditions.checkArgument( + lit.value() > 0, "Unsupported width for transform: %s", transform.describe()); + return lit.value(); + + } else if (((Literal) expr).dataType() instanceof LongType) { + Literal lit = (Literal) expr; + Preconditions.checkArgument( + lit.value() > 0 && lit.value() < Integer.MAX_VALUE, + "Unsupported width for transform: %s", + transform.describe()); + if (lit.value() > Integer.MAX_VALUE) { + throw new IllegalArgumentException(); + } + return lit.value().intValue(); + } + } + } + + throw new IllegalArgumentException("Cannot find width for transform: " + transform.describe()); + } + + private static String leafName(String[] fieldNames) { + Preconditions.checkArgument( + fieldNames.length > 0, "Invalid field name: at least one name is required"); + return fieldNames[fieldNames.length - 1]; + } + + private static String peerName(String[] fieldNames, String fieldName) { + if (fieldNames.length > 1) { + String[] peerNames = Arrays.copyOf(fieldNames, fieldNames.length); + peerNames[fieldNames.length - 1] = fieldName; + return DOT.join(peerNames); + } + return fieldName; + } + + private static String parentName(String[] fieldNames) { + if (fieldNames.length > 1) { + return DOT.join(Arrays.copyOfRange(fieldNames, 0, fieldNames.length - 1)); + } + return null; + } + + public static String describe(List exprs) { + return exprs.stream().map(Spark3Util::describe).collect(Collectors.joining(", ")); + } + + public static String describe(org.apache.iceberg.expressions.Expression expr) { + return ExpressionVisitors.visit(expr, DescribeExpressionVisitor.INSTANCE); + } + + public static String describe(Schema schema) { + return TypeUtil.visit(schema, DescribeSchemaVisitor.INSTANCE); + } + + public static String describe(Type type) { + return TypeUtil.visit(type, DescribeSchemaVisitor.INSTANCE); + } + + public static String describe(org.apache.iceberg.SortOrder order) { + return Joiner.on(", 
").join(SortOrderVisitor.visit(order, DescribeSortOrderVisitor.INSTANCE)); + } + + public static boolean extensionsEnabled(SparkSession spark) { + String extensions = spark.conf().get("spark.sql.extensions", ""); + return extensions.contains("IcebergSparkSessionExtensions"); + } + + public static class DescribeSchemaVisitor extends TypeUtil.SchemaVisitor { + private static final Joiner COMMA = Joiner.on(','); + private static final DescribeSchemaVisitor INSTANCE = new DescribeSchemaVisitor(); + + private DescribeSchemaVisitor() {} + + @Override + public String schema(Schema schema, String structResult) { + return structResult; + } + + @Override + public String struct(Types.StructType struct, List fieldResults) { + return "struct<" + COMMA.join(fieldResults) + ">"; + } + + @Override + public String field(Types.NestedField field, String fieldResult) { + return field.name() + ": " + fieldResult + (field.isRequired() ? " not null" : ""); + } + + @Override + public String list(Types.ListType list, String elementResult) { + return "list<" + elementResult + ">"; + } + + @Override + public String map(Types.MapType map, String keyResult, String valueResult) { + return "map<" + keyResult + ", " + valueResult + ">"; + } + + @Override + public String primitive(Type.PrimitiveType primitive) { + switch (primitive.typeId()) { + case BOOLEAN: + return "boolean"; + case INTEGER: + return "int"; + case LONG: + return "bigint"; + case FLOAT: + return "float"; + case DOUBLE: + return "double"; + case DATE: + return "date"; + case TIME: + return "time"; + case TIMESTAMP: + return "timestamp"; + case STRING: + case UUID: + return "string"; + case FIXED: + case BINARY: + return "binary"; + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + return "decimal(" + decimal.precision() + "," + decimal.scale() + ")"; + } + throw new UnsupportedOperationException("Cannot convert type to SQL: " + primitive); + } + } + + private static class DescribeExpressionVisitor + extends ExpressionVisitors.ExpressionVisitor { + private static final DescribeExpressionVisitor INSTANCE = new DescribeExpressionVisitor(); + + private DescribeExpressionVisitor() {} + + @Override + public String alwaysTrue() { + return "true"; + } + + @Override + public String alwaysFalse() { + return "false"; + } + + @Override + public String not(String result) { + return "NOT (" + result + ")"; + } + + @Override + public String and(String leftResult, String rightResult) { + return "(" + leftResult + " AND " + rightResult + ")"; + } + + @Override + public String or(String leftResult, String rightResult) { + return "(" + leftResult + " OR " + rightResult + ")"; + } + + @Override + public String predicate(BoundPredicate pred) { + throw new UnsupportedOperationException("Cannot convert bound predicates to SQL"); + } + + @Override + public String predicate(UnboundPredicate pred) { + switch (pred.op()) { + case IS_NULL: + return sqlString(pred.term()) + " IS NULL"; + case NOT_NULL: + return sqlString(pred.term()) + " IS NOT NULL"; + case IS_NAN: + return "is_nan(" + sqlString(pred.term()) + ")"; + case NOT_NAN: + return "not_nan(" + sqlString(pred.term()) + ")"; + case LT: + return sqlString(pred.term()) + " < " + sqlString(pred.literal()); + case LT_EQ: + return sqlString(pred.term()) + " <= " + sqlString(pred.literal()); + case GT: + return sqlString(pred.term()) + " > " + sqlString(pred.literal()); + case GT_EQ: + return sqlString(pred.term()) + " >= " + sqlString(pred.literal()); + case EQ: + return sqlString(pred.term()) + " = " + 
sqlString(pred.literal()); + case NOT_EQ: + return sqlString(pred.term()) + " != " + sqlString(pred.literal()); + case STARTS_WITH: + return sqlString(pred.term()) + " LIKE '" + pred.literal().value() + "%'"; + case NOT_STARTS_WITH: + return sqlString(pred.term()) + " NOT LIKE '" + pred.literal().value() + "%'"; + case IN: + return sqlString(pred.term()) + " IN (" + sqlString(pred.literals()) + ")"; + case NOT_IN: + return sqlString(pred.term()) + " NOT IN (" + sqlString(pred.literals()) + ")"; + default: + throw new UnsupportedOperationException("Cannot convert predicate to SQL: " + pred); + } + } + + private static String sqlString(UnboundTerm term) { + if (term instanceof org.apache.iceberg.expressions.NamedReference) { + return term.ref().name(); + } else if (term instanceof UnboundTransform) { + UnboundTransform transform = (UnboundTransform) term; + return transform.transform().toString() + "(" + transform.ref().name() + ")"; + } else { + throw new UnsupportedOperationException("Cannot convert term to SQL: " + term); + } + } + + private static String sqlString(List> literals) { + return literals.stream() + .map(DescribeExpressionVisitor::sqlString) + .collect(Collectors.joining(", ")); + } + + private static String sqlString(org.apache.iceberg.expressions.Literal lit) { + if (lit.value() instanceof String) { + return "'" + lit.value() + "'"; + } else if (lit.value() instanceof ByteBuffer) { + byte[] bytes = ByteBuffers.toByteArray((ByteBuffer) lit.value()); + return "X'" + BaseEncoding.base16().encode(bytes) + "'"; + } else { + return lit.value().toString(); + } + } + } + + /** + * Returns an Iceberg Table by its name from a Spark V2 Catalog. If cache is enabled in {@link + * SparkCatalog}, the {@link TableOperations} of the table may be stale, please refresh the table + * to get the latest one. + * + * @param spark SparkSession used for looking up catalog references and tables + * @param name The multipart identifier of the Iceberg table + * @return an Iceberg table + */ + public static org.apache.iceberg.Table loadIcebergTable(SparkSession spark, String name) + throws ParseException, NoSuchTableException { + CatalogAndIdentifier catalogAndIdentifier = catalogAndIdentifier(spark, name); + + TableCatalog catalog = asTableCatalog(catalogAndIdentifier.catalog); + Table sparkTable = catalog.loadTable(catalogAndIdentifier.identifier); + return toIcebergTable(sparkTable); + } + + /** + * Returns the underlying Iceberg Catalog object represented by a Spark Catalog + * + * @param spark SparkSession used for looking up catalog reference + * @param catalogName The name of the Spark Catalog being referenced + * @return the Iceberg catalog class being wrapped by the Spark Catalog + */ + public static Catalog loadIcebergCatalog(SparkSession spark, String catalogName) { + CatalogPlugin catalogPlugin = spark.sessionState().catalogManager().catalog(catalogName); + Preconditions.checkArgument( + catalogPlugin instanceof HasIcebergCatalog, + String.format( + "Cannot load Iceberg catalog from catalog %s because it does not contain an Iceberg Catalog. 
" + + "Actual Class: %s", + catalogName, catalogPlugin.getClass().getName())); + return ((HasIcebergCatalog) catalogPlugin).icebergCatalog(); + } + + public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, String name) + throws ParseException { + return catalogAndIdentifier( + spark, name, spark.sessionState().catalogManager().currentCatalog()); + } + + public static CatalogAndIdentifier catalogAndIdentifier( + SparkSession spark, String name, CatalogPlugin defaultCatalog) throws ParseException { + ParserInterface parser = spark.sessionState().sqlParser(); + Seq multiPartIdentifier = parser.parseMultipartIdentifier(name).toIndexedSeq(); + List javaMultiPartIdentifier = JavaConverters.seqAsJavaList(multiPartIdentifier); + return catalogAndIdentifier(spark, javaMultiPartIdentifier, defaultCatalog); + } + + public static CatalogAndIdentifier catalogAndIdentifier( + String description, SparkSession spark, String name) { + return catalogAndIdentifier( + description, spark, name, spark.sessionState().catalogManager().currentCatalog()); + } + + public static CatalogAndIdentifier catalogAndIdentifier( + String description, SparkSession spark, String name, CatalogPlugin defaultCatalog) { + try { + return catalogAndIdentifier(spark, name, defaultCatalog); + } catch (ParseException e) { + throw new IllegalArgumentException("Cannot parse " + description + ": " + name, e); + } + } + + public static CatalogAndIdentifier catalogAndIdentifier( + SparkSession spark, List nameParts) { + return catalogAndIdentifier( + spark, nameParts, spark.sessionState().catalogManager().currentCatalog()); + } + + /** + * A modified version of Spark's LookupCatalog.CatalogAndIdentifier.unapply Attempts to find the + * catalog and identifier a multipart identifier represents + * + * @param spark Spark session to use for resolution + * @param nameParts Multipart identifier representing a table + * @param defaultCatalog Catalog to use if none is specified + * @return The CatalogPlugin and Identifier for the table + */ + public static CatalogAndIdentifier catalogAndIdentifier( + SparkSession spark, List nameParts, CatalogPlugin defaultCatalog) { + CatalogManager catalogManager = spark.sessionState().catalogManager(); + + String[] currentNamespace; + if (defaultCatalog.equals(catalogManager.currentCatalog())) { + currentNamespace = catalogManager.currentNamespace(); + } else { + currentNamespace = defaultCatalog.defaultNamespace(); + } + + Pair catalogIdentifier = + SparkUtil.catalogAndIdentifier( + nameParts, + catalogName -> { + try { + return catalogManager.catalog(catalogName); + } catch (Exception e) { + return null; + } + }, + Identifier::of, + defaultCatalog, + currentNamespace); + return new CatalogAndIdentifier(catalogIdentifier); + } + + private static TableCatalog asTableCatalog(CatalogPlugin catalog) { + if (catalog instanceof TableCatalog) { + return (TableCatalog) catalog; + } + + throw new IllegalArgumentException( + String.format( + "Cannot use catalog %s(%s): not a TableCatalog", + catalog.name(), catalog.getClass().getName())); + } + + /** This mimics a class inside of Spark which is private inside of LookupCatalog. 
*/ + public static class CatalogAndIdentifier { + private final CatalogPlugin catalog; + private final Identifier identifier; + + public CatalogAndIdentifier(CatalogPlugin catalog, Identifier identifier) { + this.catalog = catalog; + this.identifier = identifier; + } + + public CatalogAndIdentifier(Pair identifier) { + this.catalog = identifier.first(); + this.identifier = identifier.second(); + } + + public CatalogPlugin catalog() { + return catalog; + } + + public Identifier identifier() { + return identifier; + } + } + + public static TableIdentifier identifierToTableIdentifier(Identifier identifier) { + return TableIdentifier.of(Namespace.of(identifier.namespace()), identifier.name()); + } + + public static String quotedFullIdentifier(String catalogName, Identifier identifier) { + List parts = + ImmutableList.builder() + .add(catalogName) + .addAll(Arrays.asList(identifier.namespace())) + .add(identifier.name()) + .build(); + + return CatalogV2Implicits.MultipartIdentifierHelper( + JavaConverters.asScalaIteratorConverter(parts.iterator()).asScala().toSeq()) + .quoted(); + } + + /** + * Use Spark to list all partitions in the table. + * + * @param spark a Spark session + * @param rootPath a table identifier + * @param format format of the file + * @param partitionFilter partitionFilter of the file + * @param partitionSpec partitionSpec of the table + * @return all table's partitions + */ + public static List getPartitions( + SparkSession spark, + Path rootPath, + String format, + Map partitionFilter, + PartitionSpec partitionSpec) { + FileStatusCache fileStatusCache = FileStatusCache.getOrCreate(spark); + + Option userSpecifiedSchema = + partitionSpec == null + ? Option.empty() + : Option.apply( + SparkSchemaUtil.convert(new Schema(partitionSpec.partitionType().fields()))); + + InMemoryFileIndex fileIndex = + new InMemoryFileIndex( + spark, + JavaConverters.collectionAsScalaIterableConverter(ImmutableList.of(rootPath)) + .asScala() + .toSeq(), + scala.collection.immutable.Map$.MODULE$.empty(), + userSpecifiedSchema, + fileStatusCache, + Option.empty(), + Option.empty()); + + org.apache.spark.sql.execution.datasources.PartitionSpec spec = fileIndex.partitionSpec(); + StructType schema = spec.partitionColumns(); + if (schema.isEmpty()) { + return Lists.newArrayList(); + } + + List filterExpressions = + SparkUtil.partitionMapToExpression(schema, partitionFilter); + Seq scalaPartitionFilters = + JavaConverters.asScalaBufferConverter(filterExpressions).asScala().toIndexedSeq(); + + List dataFilters = Lists.newArrayList(); + Seq scalaDataFilters = + JavaConverters.asScalaBufferConverter(dataFilters).asScala().toIndexedSeq(); + + Seq filteredPartitions = + fileIndex.listFiles(scalaPartitionFilters, scalaDataFilters).toIndexedSeq(); + + return JavaConverters.seqAsJavaListConverter(filteredPartitions).asJava().stream() + .map( + partition -> { + Map values = Maps.newHashMap(); + JavaConverters.asJavaIterableConverter(schema) + .asJava() + .forEach( + field -> { + int fieldIndex = schema.fieldIndex(field.name()); + Object catalystValue = partition.values().get(fieldIndex, field.dataType()); + Object value = + CatalystTypeConverters.convertToScala(catalystValue, field.dataType()); + values.put(field.name(), String.valueOf(value)); + }); + + FileStatusWithMetadata fileStatus = + JavaConverters.seqAsJavaListConverter(partition.files()).asJava().get(0); + + return new SparkPartition( + values, fileStatus.getPath().getParent().toString(), format); + }) + .collect(Collectors.toList()); + } + + public 
static org.apache.spark.sql.catalyst.TableIdentifier toV1TableIdentifier( + Identifier identifier) { + String[] namespace = identifier.namespace(); + + Preconditions.checkArgument( + namespace.length <= 1, + "Cannot convert %s to a Spark v1 identifier, namespace contains more than 1 part", + identifier); + + String table = identifier.name(); + Option database = namespace.length == 1 ? Option.apply(namespace[0]) : Option.empty(); + return org.apache.spark.sql.catalyst.TableIdentifier.apply(table, database); + } + + static String baseTableUUID(org.apache.iceberg.Table table) { + if (table instanceof HasTableOperations) { + TableOperations ops = ((HasTableOperations) table).operations(); + return ops.current().uuid(); + } else if (table instanceof BaseMetadataTable) { + return ((BaseMetadataTable) table).table().operations().current().uuid(); + } else { + throw new UnsupportedOperationException("Cannot retrieve UUID for table " + table.name()); + } + } + + private static class DescribeSortOrderVisitor implements SortOrderVisitor { + private static final DescribeSortOrderVisitor INSTANCE = new DescribeSortOrderVisitor(); + + private DescribeSortOrderVisitor() {} + + @Override + public String field( + String sourceName, + int sourceId, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("%s %s %s", sourceName, direction, nullOrder); + } + + @Override + public String bucket( + String sourceName, + int sourceId, + int numBuckets, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("bucket(%s, %s) %s %s", numBuckets, sourceName, direction, nullOrder); + } + + @Override + public String truncate( + String sourceName, + int sourceId, + int width, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("truncate(%s, %s) %s %s", sourceName, width, direction, nullOrder); + } + + @Override + public String year( + String sourceName, + int sourceId, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("years(%s) %s %s", sourceName, direction, nullOrder); + } + + @Override + public String month( + String sourceName, + int sourceId, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("months(%s) %s %s", sourceName, direction, nullOrder); + } + + @Override + public String day( + String sourceName, + int sourceId, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("days(%s) %s %s", sourceName, direction, nullOrder); + } + + @Override + public String hour( + String sourceName, + int sourceId, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("hours(%s) %s %s", sourceName, direction, nullOrder); + } + + @Override + public String unknown( + String sourceName, + int sourceId, + String transform, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("%s(%s) %s %s", transform, sourceName, direction, nullOrder); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkAggregates.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkAggregates.java new file mode 100644 index 000000000000..153ef11a9eb6 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkAggregates.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expression.Operation; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc; +import org.apache.spark.sql.connector.expressions.aggregate.Count; +import org.apache.spark.sql.connector.expressions.aggregate.CountStar; +import org.apache.spark.sql.connector.expressions.aggregate.Max; +import org.apache.spark.sql.connector.expressions.aggregate.Min; + +public class SparkAggregates { + private SparkAggregates() {} + + private static final Map, Operation> AGGREGATES = + ImmutableMap., Operation>builder() + .put(Count.class, Operation.COUNT) + .put(CountStar.class, Operation.COUNT_STAR) + .put(Max.class, Operation.MAX) + .put(Min.class, Operation.MIN) + .buildOrThrow(); + + public static Expression convert(AggregateFunc aggregate) { + Operation op = AGGREGATES.get(aggregate.getClass()); + if (op != null) { + switch (op) { + case COUNT: + Count countAgg = (Count) aggregate; + if (countAgg.isDistinct()) { + // manifest file doesn't have count distinct so this can't be pushed down + return null; + } + + if (countAgg.column() instanceof NamedReference) { + return Expressions.count(SparkUtil.toColumnName((NamedReference) countAgg.column())); + } else { + return null; + } + + case COUNT_STAR: + return Expressions.countStar(); + + case MAX: + Max maxAgg = (Max) aggregate; + if (maxAgg.column() instanceof NamedReference) { + return Expressions.max(SparkUtil.toColumnName((NamedReference) maxAgg.column())); + } else { + return null; + } + + case MIN: + Min minAgg = (Min) aggregate; + if (minAgg.column() instanceof NamedReference) { + return Expressions.min(SparkUtil.toColumnName((NamedReference) minAgg.column())); + } else { + return null; + } + } + } + + return null; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkCachedTableCatalog.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkCachedTableCatalog.java new file mode 100644 index 000000000000..21317526d2aa --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkCachedTableCatalog.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
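A small sketch of the aggregate conversion above (illustrative; it assumes Spark's DSv2 aggregate classes can be constructed directly as shown): supported MIN/MAX/COUNT aggregates over plain column references map to Iceberg expressions, while anything else returns null and is simply not pushed down.

import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.spark.SparkAggregates;
import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
import org.apache.spark.sql.connector.expressions.aggregate.Max;

class AggregatePushdownSketch {
  public static void main(String[] args) {
    // MAX(x) over a plain column reference converts to Iceberg's max("x")
    Expression max = SparkAggregates.convert(new Max(Expressions.column("x")));

    // COUNT(*) converts to countStar()
    Expression countStar = SparkAggregates.convert(new CountStar());

    System.out.println(max + ", " + countStar);
  }
}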
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Stream; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** An internal table catalog that is capable of loading tables from a cache. */ +public class SparkCachedTableCatalog implements TableCatalog, SupportsFunctions { + + private static final String CLASS_NAME = SparkCachedTableCatalog.class.getName(); + private static final Splitter COMMA = Splitter.on(","); + private static final Pattern AT_TIMESTAMP = Pattern.compile("at_timestamp_(\\d+)"); + private static final Pattern SNAPSHOT_ID = Pattern.compile("snapshot_id_(\\d+)"); + private static final Pattern BRANCH = Pattern.compile("branch_(.*)"); + private static final Pattern TAG = Pattern.compile("tag_(.*)"); + + private static final SparkTableCache TABLE_CACHE = SparkTableCache.get(); + + private String name = null; + + @Override + public Identifier[] listTables(String[] namespace) { + throw new UnsupportedOperationException(CLASS_NAME + " does not support listing tables"); + } + + @Override + public SparkTable loadTable(Identifier ident) throws NoSuchTableException { + Pair table = load(ident); + return new SparkTable(table.first(), table.second(), false /* refresh eagerly */); + } + + @Override + public SparkTable loadTable(Identifier ident, String version) throws NoSuchTableException { + Pair table = load(ident); + Preconditions.checkArgument( + table.second() == null, "Cannot time travel based on both table identifier and AS OF"); + return new SparkTable(table.first(), Long.parseLong(version), false /* refresh eagerly */); + } + + @Override + public SparkTable loadTable(Identifier ident, long timestampMicros) throws NoSuchTableException { + Pair table = load(ident); + Preconditions.checkArgument( + table.second() == null, "Cannot time travel based on both table identifier and AS OF"); + // Spark passes microseconds but Iceberg uses milliseconds for snapshots + long timestampMillis = TimeUnit.MICROSECONDS.toMillis(timestampMicros); + long snapshotId = SnapshotUtil.snapshotIdAsOfTime(table.first(), 
timestampMillis); + return new SparkTable(table.first(), snapshotId, false /* refresh eagerly */); + } + + @Override + public void invalidateTable(Identifier ident) { + throw new UnsupportedOperationException(CLASS_NAME + " does not support table invalidation"); + } + + @Override + public SparkTable createTable( + Identifier ident, StructType schema, Transform[] partitions, Map properties) + throws TableAlreadyExistsException { + throw new UnsupportedOperationException(CLASS_NAME + " does not support creating tables"); + } + + @Override + public SparkTable alterTable(Identifier ident, TableChange... changes) { + throw new UnsupportedOperationException(CLASS_NAME + " does not support altering tables"); + } + + @Override + public boolean dropTable(Identifier ident) { + throw new UnsupportedOperationException(CLASS_NAME + " does not support dropping tables"); + } + + @Override + public boolean purgeTable(Identifier ident) throws UnsupportedOperationException { + throw new UnsupportedOperationException(CLASS_NAME + " does not support purging tables"); + } + + @Override + public void renameTable(Identifier oldIdent, Identifier newIdent) { + throw new UnsupportedOperationException(CLASS_NAME + " does not support renaming tables"); + } + + @Override + public void initialize(String catalogName, CaseInsensitiveStringMap options) { + this.name = catalogName; + } + + @Override + public String name() { + return name; + } + + private Pair load(Identifier ident) throws NoSuchTableException { + Preconditions.checkArgument( + ident.namespace().length == 0, CLASS_NAME + " does not support namespaces"); + + Pair> parsedIdent = parseIdent(ident); + String key = parsedIdent.first(); + List metadata = parsedIdent.second(); + + Long asOfTimestamp = null; + Long snapshotId = null; + String branch = null; + String tag = null; + for (String meta : metadata) { + Matcher timeBasedMatcher = AT_TIMESTAMP.matcher(meta); + if (timeBasedMatcher.matches()) { + asOfTimestamp = Long.parseLong(timeBasedMatcher.group(1)); + continue; + } + + Matcher snapshotBasedMatcher = SNAPSHOT_ID.matcher(meta); + if (snapshotBasedMatcher.matches()) { + snapshotId = Long.parseLong(snapshotBasedMatcher.group(1)); + continue; + } + + Matcher branchBasedMatcher = BRANCH.matcher(meta); + if (branchBasedMatcher.matches()) { + branch = branchBasedMatcher.group(1); + continue; + } + + Matcher tagBasedMatcher = TAG.matcher(meta); + if (tagBasedMatcher.matches()) { + tag = tagBasedMatcher.group(1); + } + } + + Preconditions.checkArgument( + Stream.of(snapshotId, asOfTimestamp, branch, tag).filter(Objects::nonNull).count() <= 1, + "Can specify only one of snapshot-id (%s), as-of-timestamp (%s), branch (%s), tag (%s)", + snapshotId, + asOfTimestamp, + branch, + tag); + + Table table = TABLE_CACHE.get(key); + + if (table == null) { + throw new NoSuchTableException(ident); + } + + if (snapshotId != null) { + return Pair.of(table, snapshotId); + } else if (asOfTimestamp != null) { + return Pair.of(table, SnapshotUtil.snapshotIdAsOfTime(table, asOfTimestamp)); + } else if (branch != null) { + Snapshot branchSnapshot = table.snapshot(branch); + Preconditions.checkArgument( + branchSnapshot != null, "Cannot find snapshot associated with branch name: %s", branch); + return Pair.of(table, branchSnapshot.snapshotId()); + } else if (tag != null) { + Snapshot tagSnapshot = table.snapshot(tag); + Preconditions.checkArgument( + tagSnapshot != null, "Cannot find snapshot associated with tag name: %s", tag); + return Pair.of(table, tagSnapshot.snapshotId()); + } 
else { + return Pair.of(table, null); + } + } + + private Pair> parseIdent(Identifier ident) { + int hashIndex = ident.name().lastIndexOf('#'); + if (hashIndex != -1 && !ident.name().endsWith("#")) { + String key = ident.name().substring(0, hashIndex); + List metadata = COMMA.splitToList(ident.name().substring(hashIndex + 1)); + return Pair.of(key, metadata); + } else { + return Pair.of(ident.name(), ImmutableList.of()); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkCatalog.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkCatalog.java new file mode 100644 index 000000000000..5eec4d102cfb --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkCatalog.java @@ -0,0 +1,1014 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.GC_ENABLED_DEFAULT; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Stream; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.CachingCatalog; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.EnvironmentContext; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Transaction; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.SupportsNamespaces; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.catalog.ViewCatalog; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopCatalog; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import 
org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.source.SparkChangelogTable; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.spark.source.SparkView; +import org.apache.iceberg.spark.source.StagedSparkTable; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.iceberg.view.UpdateViewProperties; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.NoSuchViewException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.catalyst.analysis.ViewAlreadyExistsException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.NamespaceChange; +import org.apache.spark.sql.connector.catalog.StagedTable; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.connector.catalog.TableChange.ColumnChange; +import org.apache.spark.sql.connector.catalog.TableChange.RemoveProperty; +import org.apache.spark.sql.connector.catalog.TableChange.SetProperty; +import org.apache.spark.sql.connector.catalog.View; +import org.apache.spark.sql.connector.catalog.ViewChange; +import org.apache.spark.sql.connector.catalog.ViewInfo; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * A Spark TableCatalog implementation that wraps an Iceberg {@link Catalog}. + * + *

<p>This supports the following catalog configuration options:
+ *
+ * <ul>
+ *   <li><code>type</code> - catalog type, "hive" or "hadoop" or "rest". To specify a non-hive or
+ *       hadoop catalog, use the <code>catalog-impl</code> option.
+ *   <li><code>uri</code> - the Hive Metastore URI for Hive catalog or REST URI for REST catalog
+ *   <li><code>warehouse</code> - the warehouse path (Hadoop catalog only)
+ *   <li><code>catalog-impl</code> - a custom {@link Catalog} implementation to use
+ *   <li><code>io-impl</code> - a custom {@link org.apache.iceberg.io.FileIO} implementation to use
+ *   <li><code>metrics-reporter-impl</code> - a custom {@link
+ *       org.apache.iceberg.metrics.MetricsReporter} implementation to use
+ *   <li><code>default-namespace</code> - a namespace to use as the default
+ *   <li><code>cache-enabled</code> - whether to enable catalog cache
+ *   <li><code>cache.case-sensitive</code> - whether the catalog cache should compare table
+ *       identifiers in a case sensitive way
+ *   <li><code>cache.expiration-interval-ms</code> - interval in millis before expiring tables from
+ *       catalog cache. Refer to {@link CatalogProperties#CACHE_EXPIRATION_INTERVAL_MS} for further
+ *       details and significant values.
+ *   <li><code>table-default.$tablePropertyKey</code> - table property $tablePropertyKey default at
+ *       catalog level
+ *   <li><code>table-override.$tablePropertyKey</code> - table property $tablePropertyKey enforced
+ *       at catalog level
+ * </ul>
+ *
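+ * <p>For example, a catalog registered under the illustrative name {@code my_catalog} and backed
+ * by a REST server could be configured through Spark session properties; the catalog name and URI
+ * below are placeholders:
+ *
+ * <pre>{@code
+ * SparkSession spark =
+ *     SparkSession.builder()
+ *         .config("spark.sql.catalog.my_catalog", "org.apache.iceberg.spark.SparkCatalog")
+ *         .config("spark.sql.catalog.my_catalog.type", "rest")
+ *         .config("spark.sql.catalog.my_catalog.uri", "http://localhost:8181")
+ *         .config("spark.sql.catalog.my_catalog.cache-enabled", "false")
+ *         .getOrCreate();
+ * }</pre>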
+ */ +public class SparkCatalog extends BaseCatalog + implements org.apache.spark.sql.connector.catalog.ViewCatalog, SupportsReplaceView { + private static final Set DEFAULT_NS_KEYS = ImmutableSet.of(TableCatalog.PROP_OWNER); + private static final Splitter COMMA = Splitter.on(","); + private static final Joiner COMMA_JOINER = Joiner.on(","); + private static final Pattern AT_TIMESTAMP = Pattern.compile("at_timestamp_(\\d+)"); + private static final Pattern SNAPSHOT_ID = Pattern.compile("snapshot_id_(\\d+)"); + private static final Pattern BRANCH = Pattern.compile("branch_(.*)"); + private static final Pattern TAG = Pattern.compile("tag_(.*)"); + + private String catalogName = null; + private Catalog icebergCatalog = null; + private boolean cacheEnabled = CatalogProperties.CACHE_ENABLED_DEFAULT; + private SupportsNamespaces asNamespaceCatalog = null; + private ViewCatalog asViewCatalog = null; + private String[] defaultNamespace = null; + private HadoopTables tables; + + /** + * Build an Iceberg {@link Catalog} to be used by this Spark catalog adapter. + * + * @param name Spark's catalog name + * @param options Spark's catalog options + * @return an Iceberg catalog + */ + protected Catalog buildIcebergCatalog(String name, CaseInsensitiveStringMap options) { + Configuration conf = SparkUtil.hadoopConfCatalogOverrides(SparkSession.active(), name); + Map optionsMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + optionsMap.putAll(options.asCaseSensitiveMap()); + optionsMap.put(CatalogProperties.APP_ID, SparkSession.active().sparkContext().applicationId()); + optionsMap.put(CatalogProperties.USER, SparkSession.active().sparkContext().sparkUser()); + return CatalogUtil.buildIcebergCatalog(name, optionsMap, conf); + } + + /** + * Build an Iceberg {@link TableIdentifier} for the given Spark identifier. 
+ * + * @param identifier Spark's identifier + * @return an Iceberg identifier + */ + protected TableIdentifier buildIdentifier(Identifier identifier) { + return Spark3Util.identifierToTableIdentifier(identifier); + } + + @Override + public Table loadTable(Identifier ident) throws NoSuchTableException { + try { + return load(ident); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(ident); + } + } + + @Override + public Table loadTable(Identifier ident, String version) throws NoSuchTableException { + Table table = loadTable(ident); + + if (table instanceof SparkTable) { + SparkTable sparkTable = (SparkTable) table; + + Preconditions.checkArgument( + sparkTable.snapshotId() == null && sparkTable.branch() == null, + "Cannot do time-travel based on both table identifier and AS OF"); + + try { + return sparkTable.copyWithSnapshotId(Long.parseLong(version)); + } catch (NumberFormatException e) { + SnapshotRef ref = sparkTable.table().refs().get(version); + ValidationException.check( + ref != null, + "Cannot find matching snapshot ID or reference name for version " + version); + + if (ref.isBranch()) { + return sparkTable.copyWithBranch(version); + } else { + return sparkTable.copyWithSnapshotId(ref.snapshotId()); + } + } + + } else if (table instanceof SparkChangelogTable) { + throw new UnsupportedOperationException("AS OF is not supported for changelogs"); + + } else { + throw new IllegalArgumentException("Unknown Spark table type: " + table.getClass().getName()); + } + } + + @Override + public Table loadTable(Identifier ident, long timestamp) throws NoSuchTableException { + Table table = loadTable(ident); + + if (table instanceof SparkTable) { + SparkTable sparkTable = (SparkTable) table; + + Preconditions.checkArgument( + sparkTable.snapshotId() == null && sparkTable.branch() == null, + "Cannot do time-travel based on both table identifier and AS OF"); + + // convert the timestamp to milliseconds as Spark passes microseconds + // but Iceberg uses milliseconds for snapshot timestamps + long timestampMillis = TimeUnit.MICROSECONDS.toMillis(timestamp); + long snapshotId = SnapshotUtil.snapshotIdAsOfTime(sparkTable.table(), timestampMillis); + return sparkTable.copyWithSnapshotId(snapshotId); + + } else if (table instanceof SparkChangelogTable) { + throw new UnsupportedOperationException("AS OF is not supported for changelogs"); + + } else { + throw new IllegalArgumentException("Unknown Spark table type: " + table.getClass().getName()); + } + } + + @Override + public Table createTable( + Identifier ident, StructType schema, Transform[] transforms, Map properties) + throws TableAlreadyExistsException { + Schema icebergSchema = SparkSchemaUtil.convert(schema); + try { + Catalog.TableBuilder builder = newBuilder(ident, icebergSchema); + org.apache.iceberg.Table icebergTable = + builder + .withPartitionSpec(Spark3Util.toPartitionSpec(icebergSchema, transforms)) + .withLocation(properties.get("location")) + .withProperties(Spark3Util.rebuildCreateProperties(properties)) + .create(); + return new SparkTable(icebergTable, !cacheEnabled); + } catch (AlreadyExistsException e) { + throw new TableAlreadyExistsException(ident); + } + } + + @Override + public StagedTable stageCreate( + Identifier ident, StructType schema, Transform[] transforms, Map properties) + throws TableAlreadyExistsException { + Schema icebergSchema = SparkSchemaUtil.convert(schema); + try { + Catalog.TableBuilder builder = newBuilder(ident, icebergSchema); + Transaction transaction = + 
builder + .withPartitionSpec(Spark3Util.toPartitionSpec(icebergSchema, transforms)) + .withLocation(properties.get("location")) + .withProperties(Spark3Util.rebuildCreateProperties(properties)) + .createTransaction(); + return new StagedSparkTable(transaction); + } catch (AlreadyExistsException e) { + throw new TableAlreadyExistsException(ident); + } + } + + @Override + public StagedTable stageReplace( + Identifier ident, StructType schema, Transform[] transforms, Map properties) + throws NoSuchTableException { + Schema icebergSchema = SparkSchemaUtil.convert(schema); + try { + Catalog.TableBuilder builder = newBuilder(ident, icebergSchema); + Transaction transaction = + builder + .withPartitionSpec(Spark3Util.toPartitionSpec(icebergSchema, transforms)) + .withLocation(properties.get("location")) + .withProperties(Spark3Util.rebuildCreateProperties(properties)) + .replaceTransaction(); + return new StagedSparkTable(transaction); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(ident); + } + } + + @Override + public StagedTable stageCreateOrReplace( + Identifier ident, StructType schema, Transform[] transforms, Map properties) { + Schema icebergSchema = SparkSchemaUtil.convert(schema); + Catalog.TableBuilder builder = newBuilder(ident, icebergSchema); + Transaction transaction = + builder + .withPartitionSpec(Spark3Util.toPartitionSpec(icebergSchema, transforms)) + .withLocation(properties.get("location")) + .withProperties(Spark3Util.rebuildCreateProperties(properties)) + .createOrReplaceTransaction(); + return new StagedSparkTable(transaction); + } + + @Override + public Table alterTable(Identifier ident, TableChange... changes) throws NoSuchTableException { + SetProperty setLocation = null; + SetProperty setSnapshotId = null; + SetProperty pickSnapshotId = null; + List propertyChanges = Lists.newArrayList(); + List schemaChanges = Lists.newArrayList(); + + for (TableChange change : changes) { + if (change instanceof SetProperty) { + SetProperty set = (SetProperty) change; + if (TableCatalog.PROP_LOCATION.equalsIgnoreCase(set.property())) { + setLocation = set; + } else if ("current-snapshot-id".equalsIgnoreCase(set.property())) { + setSnapshotId = set; + } else if ("cherry-pick-snapshot-id".equalsIgnoreCase(set.property())) { + pickSnapshotId = set; + } else if ("sort-order".equalsIgnoreCase(set.property())) { + throw new UnsupportedOperationException( + "Cannot specify the 'sort-order' because it's a reserved table " + + "property. Please use the command 'ALTER TABLE ... WRITE ORDERED BY' to specify write sort-orders."); + } else if ("identifier-fields".equalsIgnoreCase(set.property())) { + throw new UnsupportedOperationException( + "Cannot specify the 'identifier-fields' because it's a reserved table property. " + + "Please use the command 'ALTER TABLE ... 
SET IDENTIFIER FIELDS' to specify identifier fields."); + } else { + propertyChanges.add(set); + } + } else if (change instanceof RemoveProperty) { + propertyChanges.add(change); + } else if (change instanceof ColumnChange) { + schemaChanges.add(change); + } else { + throw new UnsupportedOperationException("Cannot apply unknown table change: " + change); + } + } + + try { + org.apache.iceberg.Table table = icebergCatalog.loadTable(buildIdentifier(ident)); + commitChanges( + table, setLocation, setSnapshotId, pickSnapshotId, propertyChanges, schemaChanges); + return new SparkTable(table, true /* refreshEagerly */); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(ident); + } + } + + @Override + public boolean dropTable(Identifier ident) { + return dropTableWithoutPurging(ident); + } + + @Override + public boolean purgeTable(Identifier ident) { + try { + org.apache.iceberg.Table table = icebergCatalog.loadTable(buildIdentifier(ident)); + ValidationException.check( + PropertyUtil.propertyAsBoolean(table.properties(), GC_ENABLED, GC_ENABLED_DEFAULT), + "Cannot purge table: GC is disabled (deleting files may corrupt other tables)"); + String metadataFileLocation = + ((HasTableOperations) table).operations().current().metadataFileLocation(); + + boolean dropped = dropTableWithoutPurging(ident); + + if (dropped) { + // check whether the metadata file exists because HadoopCatalog/HadoopTables + // will drop the warehouse directly and ignore the `purge` argument + boolean metadataFileExists = table.io().newInputFile(metadataFileLocation).exists(); + + if (metadataFileExists) { + SparkActions.get().deleteReachableFiles(metadataFileLocation).io(table.io()).execute(); + } + } + + return dropped; + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + return false; + } + } + + private boolean dropTableWithoutPurging(Identifier ident) { + if (isPathIdentifier(ident)) { + return tables.dropTable(((PathIdentifier) ident).location(), false /* don't purge data */); + } else { + return icebergCatalog.dropTable(buildIdentifier(ident), false /* don't purge data */); + } + } + + @Override + public void renameTable(Identifier from, Identifier to) + throws NoSuchTableException, TableAlreadyExistsException { + try { + checkNotPathIdentifier(from, "renameTable"); + checkNotPathIdentifier(to, "renameTable"); + icebergCatalog.renameTable(buildIdentifier(from), buildIdentifier(to)); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(from); + } catch (AlreadyExistsException e) { + throw new TableAlreadyExistsException(to); + } + } + + @Override + public void invalidateTable(Identifier ident) { + if (!isPathIdentifier(ident)) { + icebergCatalog.invalidateTable(buildIdentifier(ident)); + } + } + + @Override + public Identifier[] listTables(String[] namespace) { + return icebergCatalog.listTables(Namespace.of(namespace)).stream() + .map(ident -> Identifier.of(ident.namespace().levels(), ident.name())) + .toArray(Identifier[]::new); + } + + @Override + public String[] defaultNamespace() { + if (defaultNamespace != null) { + return defaultNamespace; + } + + return new String[0]; + } + + @Override + public String[][] listNamespaces() { + if (asNamespaceCatalog != null) { + return asNamespaceCatalog.listNamespaces().stream() + .map(Namespace::levels) + .toArray(String[][]::new); + } + + return new String[0][]; + } + + @Override + public String[][] listNamespaces(String[] namespace) throws NoSuchNamespaceException { 
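+    // delegate to the wrapped Iceberg catalog when it supports namespaces; otherwise no
+    // namespaces exist under this catalog and the lookup below reports the namespace as missing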
+ if (asNamespaceCatalog != null) { + try { + return asNamespaceCatalog.listNamespaces(Namespace.of(namespace)).stream() + .map(Namespace::levels) + .toArray(String[][]::new); + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(namespace); + } + } + + throw new NoSuchNamespaceException(namespace); + } + + @Override + public Map loadNamespaceMetadata(String[] namespace) + throws NoSuchNamespaceException { + if (asNamespaceCatalog != null) { + try { + return asNamespaceCatalog.loadNamespaceMetadata(Namespace.of(namespace)); + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(namespace); + } + } + + throw new NoSuchNamespaceException(namespace); + } + + @Override + public void createNamespace(String[] namespace, Map metadata) + throws NamespaceAlreadyExistsException { + if (asNamespaceCatalog != null) { + try { + if (asNamespaceCatalog instanceof HadoopCatalog + && DEFAULT_NS_KEYS.equals(metadata.keySet())) { + // Hadoop catalog will reject metadata properties, but Spark automatically adds "owner". + // If only the automatic properties are present, replace metadata with an empty map. + asNamespaceCatalog.createNamespace(Namespace.of(namespace), ImmutableMap.of()); + } else { + asNamespaceCatalog.createNamespace(Namespace.of(namespace), metadata); + } + } catch (AlreadyExistsException e) { + throw new NamespaceAlreadyExistsException(namespace); + } + } else { + throw new UnsupportedOperationException( + "Namespaces are not supported by catalog: " + catalogName); + } + } + + @Override + public void alterNamespace(String[] namespace, NamespaceChange... changes) + throws NoSuchNamespaceException { + if (asNamespaceCatalog != null) { + Map updates = Maps.newHashMap(); + Set removals = Sets.newHashSet(); + for (NamespaceChange change : changes) { + if (change instanceof NamespaceChange.SetProperty) { + NamespaceChange.SetProperty set = (NamespaceChange.SetProperty) change; + updates.put(set.property(), set.value()); + } else if (change instanceof NamespaceChange.RemoveProperty) { + removals.add(((NamespaceChange.RemoveProperty) change).property()); + } else { + throw new UnsupportedOperationException( + "Cannot apply unknown namespace change: " + change); + } + } + + try { + if (!updates.isEmpty()) { + asNamespaceCatalog.setProperties(Namespace.of(namespace), updates); + } + + if (!removals.isEmpty()) { + asNamespaceCatalog.removeProperties(Namespace.of(namespace), removals); + } + + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(namespace); + } + } else { + throw new NoSuchNamespaceException(namespace); + } + } + + @Override + public boolean dropNamespace(String[] namespace, boolean cascade) + throws NoSuchNamespaceException { + if (asNamespaceCatalog != null) { + try { + return asNamespaceCatalog.dropNamespace(Namespace.of(namespace)); + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(namespace); + } + } + + return false; + } + + @Override + public Identifier[] listViews(String... 
namespace) { + if (null != asViewCatalog) { + return asViewCatalog.listViews(Namespace.of(namespace)).stream() + .map(ident -> Identifier.of(ident.namespace().levels(), ident.name())) + .toArray(Identifier[]::new); + } + + return new Identifier[0]; + } + + @Override + public View loadView(Identifier ident) throws NoSuchViewException { + if (null != asViewCatalog) { + try { + org.apache.iceberg.view.View view = asViewCatalog.loadView(buildIdentifier(ident)); + return new SparkView(catalogName, view); + } catch (org.apache.iceberg.exceptions.NoSuchViewException e) { + throw new NoSuchViewException(ident); + } + } + + throw new NoSuchViewException(ident); + } + + @Override + public View createView(ViewInfo viewInfo) + throws ViewAlreadyExistsException, NoSuchNamespaceException { + if (null != asViewCatalog && viewInfo != null) { + Identifier ident = viewInfo.ident(); + String sql = viewInfo.sql(); + String currentCatalog = viewInfo.currentCatalog(); + String[] currentNamespace = viewInfo.currentNamespace(); + StructType schema = viewInfo.schema(); + String[] queryColumnNames = viewInfo.queryColumnNames(); + Map properties = viewInfo.properties(); + Schema icebergSchema = SparkSchemaUtil.convert(schema); + + try { + Map props = + ImmutableMap.builder() + .putAll(Spark3Util.rebuildCreateProperties(properties)) + .put(SparkView.QUERY_COLUMN_NAMES, COMMA_JOINER.join(queryColumnNames)) + .buildKeepingLast(); + + org.apache.iceberg.view.View view = + asViewCatalog + .buildView(buildIdentifier(ident)) + .withDefaultCatalog(currentCatalog) + .withDefaultNamespace(Namespace.of(currentNamespace)) + .withQuery("spark", sql) + .withSchema(icebergSchema) + .withLocation(properties.get("location")) + .withProperties(props) + .create(); + return new SparkView(catalogName, view); + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(currentNamespace); + } catch (AlreadyExistsException e) { + throw new ViewAlreadyExistsException(ident); + } + } + + throw new UnsupportedOperationException( + "Creating a view is not supported by catalog: " + catalogName); + } + + @Override + public View replaceView( + Identifier ident, + String sql, + String currentCatalog, + String[] currentNamespace, + StructType schema, + String[] queryColumnNames, + String[] columnAliases, + String[] columnComments, + Map properties) + throws NoSuchNamespaceException, NoSuchViewException { + if (null != asViewCatalog) { + Schema icebergSchema = SparkSchemaUtil.convert(schema); + + try { + Map props = + ImmutableMap.builder() + .putAll(Spark3Util.rebuildCreateProperties(properties)) + .put(SparkView.QUERY_COLUMN_NAMES, COMMA_JOINER.join(queryColumnNames)) + .buildKeepingLast(); + + org.apache.iceberg.view.View view = + asViewCatalog + .buildView(buildIdentifier(ident)) + .withDefaultCatalog(currentCatalog) + .withDefaultNamespace(Namespace.of(currentNamespace)) + .withQuery("spark", sql) + .withSchema(icebergSchema) + .withLocation(properties.get("location")) + .withProperties(props) + .createOrReplace(); + return new SparkView(catalogName, view); + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(currentNamespace); + } catch (org.apache.iceberg.exceptions.NoSuchViewException e) { + throw new NoSuchViewException(ident); + } + } + + throw new UnsupportedOperationException( + "Replacing a view is not supported by catalog: " + catalogName); + } + + @Override + public View alterView(Identifier ident, ViewChange... 
changes) + throws NoSuchViewException, IllegalArgumentException { + if (null != asViewCatalog) { + try { + org.apache.iceberg.view.View view = asViewCatalog.loadView(buildIdentifier(ident)); + UpdateViewProperties updateViewProperties = view.updateProperties(); + + for (ViewChange change : changes) { + if (change instanceof ViewChange.SetProperty) { + ViewChange.SetProperty property = (ViewChange.SetProperty) change; + verifyNonReservedPropertyIsSet(property.property()); + updateViewProperties.set(property.property(), property.value()); + } else if (change instanceof ViewChange.RemoveProperty) { + ViewChange.RemoveProperty remove = (ViewChange.RemoveProperty) change; + verifyNonReservedPropertyIsUnset(remove.property()); + updateViewProperties.remove(remove.property()); + } + } + + updateViewProperties.commit(); + + return new SparkView(catalogName, view); + } catch (org.apache.iceberg.exceptions.NoSuchViewException e) { + throw new NoSuchViewException(ident); + } + } + + throw new UnsupportedOperationException( + "Altering a view is not supported by catalog: " + catalogName); + } + + private static void verifyNonReservedProperty(String property, String errorMsg) { + if (SparkView.RESERVED_PROPERTIES.contains(property)) { + throw new UnsupportedOperationException(String.format(errorMsg, property)); + } + } + + private static void verifyNonReservedPropertyIsUnset(String property) { + verifyNonReservedProperty(property, "Cannot unset reserved property: '%s'"); + } + + private static void verifyNonReservedPropertyIsSet(String property) { + verifyNonReservedProperty(property, "Cannot set reserved property: '%s'"); + } + + @Override + public boolean dropView(Identifier ident) { + if (null != asViewCatalog) { + return asViewCatalog.dropView(buildIdentifier(ident)); + } + + return false; + } + + @Override + public void renameView(Identifier fromIdentifier, Identifier toIdentifier) + throws NoSuchViewException, ViewAlreadyExistsException { + if (null != asViewCatalog) { + try { + asViewCatalog.renameView(buildIdentifier(fromIdentifier), buildIdentifier(toIdentifier)); + } catch (org.apache.iceberg.exceptions.NoSuchViewException e) { + throw new NoSuchViewException(fromIdentifier); + } catch (org.apache.iceberg.exceptions.AlreadyExistsException e) { + throw new ViewAlreadyExistsException(toIdentifier); + } + } else { + throw new UnsupportedOperationException( + "Renaming a view is not supported by catalog: " + catalogName); + } + } + + @Override + public final void initialize(String name, CaseInsensitiveStringMap options) { + super.initialize(name, options); + + this.cacheEnabled = + PropertyUtil.propertyAsBoolean( + options, CatalogProperties.CACHE_ENABLED, CatalogProperties.CACHE_ENABLED_DEFAULT); + + boolean cacheCaseSensitive = + PropertyUtil.propertyAsBoolean( + options, + CatalogProperties.CACHE_CASE_SENSITIVE, + CatalogProperties.CACHE_CASE_SENSITIVE_DEFAULT); + + long cacheExpirationIntervalMs = + PropertyUtil.propertyAsLong( + options, + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS, + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS_DEFAULT); + + // An expiration interval of 0ms effectively disables caching. + // Do not wrap with CachingCatalog. 
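+    // note: this overrides an explicit cache-enabled=true setting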
+ if (cacheExpirationIntervalMs == 0) { + this.cacheEnabled = false; + } + + Catalog catalog = buildIcebergCatalog(name, options); + + this.catalogName = name; + SparkSession sparkSession = SparkSession.active(); + this.tables = + new HadoopTables(SparkUtil.hadoopConfCatalogOverrides(SparkSession.active(), name)); + this.icebergCatalog = + cacheEnabled + ? CachingCatalog.wrap(catalog, cacheCaseSensitive, cacheExpirationIntervalMs) + : catalog; + if (catalog instanceof SupportsNamespaces) { + this.asNamespaceCatalog = (SupportsNamespaces) catalog; + if (options.containsKey("default-namespace")) { + this.defaultNamespace = + Splitter.on('.').splitToList(options.get("default-namespace")).toArray(new String[0]); + } + } + + if (catalog instanceof ViewCatalog) { + this.asViewCatalog = (ViewCatalog) catalog; + } + + EnvironmentContext.put(EnvironmentContext.ENGINE_NAME, "spark"); + EnvironmentContext.put( + EnvironmentContext.ENGINE_VERSION, sparkSession.sparkContext().version()); + EnvironmentContext.put(CatalogProperties.APP_ID, sparkSession.sparkContext().applicationId()); + } + + @Override + public String name() { + return catalogName; + } + + private static void commitChanges( + org.apache.iceberg.Table table, + SetProperty setLocation, + SetProperty setSnapshotId, + SetProperty pickSnapshotId, + List propertyChanges, + List schemaChanges) { + // don't allow setting the snapshot and picking a commit at the same time because order is + // ambiguous and choosing one order leads to different results + Preconditions.checkArgument( + setSnapshotId == null || pickSnapshotId == null, + "Cannot set the current the current snapshot ID and cherry-pick snapshot changes"); + + if (setSnapshotId != null) { + long newSnapshotId = Long.parseLong(setSnapshotId.value()); + table.manageSnapshots().setCurrentSnapshot(newSnapshotId).commit(); + } + + // if updating the table snapshot, perform that update first in case it fails + if (pickSnapshotId != null) { + long newSnapshotId = Long.parseLong(pickSnapshotId.value()); + table.manageSnapshots().cherrypick(newSnapshotId).commit(); + } + + Transaction transaction = table.newTransaction(); + + if (setLocation != null) { + transaction.updateLocation().setLocation(setLocation.value()).commit(); + } + + if (!propertyChanges.isEmpty()) { + Spark3Util.applyPropertyChanges(transaction.updateProperties(), propertyChanges).commit(); + } + + if (!schemaChanges.isEmpty()) { + Spark3Util.applySchemaChanges(transaction.updateSchema(), schemaChanges).commit(); + } + + transaction.commitTransaction(); + } + + private static boolean isPathIdentifier(Identifier ident) { + return ident instanceof PathIdentifier; + } + + private static void checkNotPathIdentifier(Identifier identifier, String method) { + if (identifier instanceof PathIdentifier) { + throw new IllegalArgumentException( + String.format( + "Cannot pass path based identifier to %s method. 
%s is a path.", method, identifier)); + } + } + + private Table load(Identifier ident) { + if (isPathIdentifier(ident)) { + return loadFromPathIdentifier((PathIdentifier) ident); + } + + try { + org.apache.iceberg.Table table = icebergCatalog.loadTable(buildIdentifier(ident)); + return new SparkTable(table, !cacheEnabled); + + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + if (ident.namespace().length == 0) { + throw e; + } + + // if the original load didn't work, try using the namespace as an identifier because + // the original identifier may include a snapshot selector or may point to the changelog + TableIdentifier namespaceAsIdent = buildIdentifier(namespaceToIdentifier(ident.namespace())); + org.apache.iceberg.Table table; + try { + table = icebergCatalog.loadTable(namespaceAsIdent); + } catch (Exception ignored) { + // the namespace does not identify a table, so it cannot be a table with a snapshot selector + // throw the original exception + throw e; + } + + // loading the namespace as a table worked, check the name to see if it is a valid selector + // or if the name points to the changelog + + if (ident.name().equalsIgnoreCase(SparkChangelogTable.TABLE_NAME)) { + return new SparkChangelogTable(table, !cacheEnabled); + } + + Matcher at = AT_TIMESTAMP.matcher(ident.name()); + if (at.matches()) { + long asOfTimestamp = Long.parseLong(at.group(1)); + long snapshotId = SnapshotUtil.snapshotIdAsOfTime(table, asOfTimestamp); + return new SparkTable(table, snapshotId, !cacheEnabled); + } + + Matcher id = SNAPSHOT_ID.matcher(ident.name()); + if (id.matches()) { + long snapshotId = Long.parseLong(id.group(1)); + return new SparkTable(table, snapshotId, !cacheEnabled); + } + + Matcher branch = BRANCH.matcher(ident.name()); + if (branch.matches()) { + return new SparkTable(table, branch.group(1), !cacheEnabled); + } + + Matcher tag = TAG.matcher(ident.name()); + if (tag.matches()) { + Snapshot tagSnapshot = table.snapshot(tag.group(1)); + if (tagSnapshot != null) { + return new SparkTable(table, tagSnapshot.snapshotId(), !cacheEnabled); + } + } + + // the name wasn't a valid snapshot selector and did not point to the changelog + // throw the original exception + throw e; + } + } + + private Pair> parseLocationString(String location) { + int hashIndex = location.lastIndexOf('#'); + if (hashIndex != -1 && !location.endsWith("#")) { + String baseLocation = location.substring(0, hashIndex); + List metadata = COMMA.splitToList(location.substring(hashIndex + 1)); + return Pair.of(baseLocation, metadata); + } else { + return Pair.of(location, ImmutableList.of()); + } + } + + @SuppressWarnings("CyclomaticComplexity") + private Table loadFromPathIdentifier(PathIdentifier ident) { + Pair> parsed = parseLocationString(ident.location()); + + String metadataTableName = null; + Long asOfTimestamp = null; + Long snapshotId = null; + String branch = null; + String tag = null; + boolean isChangelog = false; + + for (String meta : parsed.second()) { + if (meta.equalsIgnoreCase(SparkChangelogTable.TABLE_NAME)) { + isChangelog = true; + continue; + } + + if (MetadataTableType.from(meta) != null) { + metadataTableName = meta; + continue; + } + + Matcher at = AT_TIMESTAMP.matcher(meta); + if (at.matches()) { + asOfTimestamp = Long.parseLong(at.group(1)); + continue; + } + + Matcher id = SNAPSHOT_ID.matcher(meta); + if (id.matches()) { + snapshotId = Long.parseLong(id.group(1)); + continue; + } + + Matcher branchRef = BRANCH.matcher(meta); + if (branchRef.matches()) { + branch = 
branchRef.group(1); + continue; + } + + Matcher tagRef = TAG.matcher(meta); + if (tagRef.matches()) { + tag = tagRef.group(1); + } + } + + Preconditions.checkArgument( + Stream.of(snapshotId, asOfTimestamp, branch, tag).filter(Objects::nonNull).count() <= 1, + "Can specify only one of snapshot-id (%s), as-of-timestamp (%s), branch (%s), tag (%s)", + snapshotId, + asOfTimestamp, + branch, + tag); + + Preconditions.checkArgument( + !isChangelog || (snapshotId == null && asOfTimestamp == null), + "Cannot specify snapshot-id and as-of-timestamp for changelogs"); + + org.apache.iceberg.Table table = + tables.load(parsed.first() + (metadataTableName != null ? "#" + metadataTableName : "")); + + if (isChangelog) { + return new SparkChangelogTable(table, !cacheEnabled); + + } else if (asOfTimestamp != null) { + long snapshotIdAsOfTime = SnapshotUtil.snapshotIdAsOfTime(table, asOfTimestamp); + return new SparkTable(table, snapshotIdAsOfTime, !cacheEnabled); + + } else if (branch != null) { + return new SparkTable(table, branch, !cacheEnabled); + + } else if (tag != null) { + Snapshot tagSnapshot = table.snapshot(tag); + Preconditions.checkArgument( + tagSnapshot != null, "Cannot find snapshot associated with tag name: %s", tag); + return new SparkTable(table, tagSnapshot.snapshotId(), !cacheEnabled); + + } else { + return new SparkTable(table, snapshotId, !cacheEnabled); + } + } + + private Identifier namespaceToIdentifier(String[] namespace) { + Preconditions.checkArgument( + namespace.length > 0, "Cannot convert empty namespace to identifier"); + String[] ns = Arrays.copyOf(namespace, namespace.length - 1); + String name = namespace[ns.length]; + return Identifier.of(ns, name); + } + + private Catalog.TableBuilder newBuilder(Identifier ident, Schema schema) { + return isPathIdentifier(ident) + ? tables.buildTable(((PathIdentifier) ident).location(), schema) + : icebergCatalog.buildTable(buildIdentifier(ident), schema); + } + + @Override + public Catalog icebergCatalog() { + return icebergCatalog; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkCompressionUtil.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkCompressionUtil.java new file mode 100644 index 000000000000..8f00b7f8301d --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkCompressionUtil.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.util.Pair; +import org.apache.spark.SparkConf; +import org.apache.spark.sql.SparkSession; + +class SparkCompressionUtil { + + private static final String LZ4 = "lz4"; + private static final String ZSTD = "zstd"; + private static final String GZIP = "gzip"; + private static final String ZLIB = "zlib"; + private static final String SNAPPY = "snappy"; + private static final String NONE = "none"; + + // an internal Spark config that controls whether shuffle data is compressed + private static final String SHUFFLE_COMPRESSION_ENABLED = "spark.shuffle.compress"; + private static final boolean SHUFFLE_COMPRESSION_ENABLED_DEFAULT = true; + + // an internal Spark config that controls what compression codec is used + private static final String SPARK_COMPRESSION_CODEC = "spark.io.compression.codec"; + private static final String SPARK_COMPRESSION_CODEC_DEFAULT = "lz4"; + + private static final double DEFAULT_COLUMNAR_COMPRESSION = 2; + private static final Map, Double> COLUMNAR_COMPRESSIONS = + initColumnarCompressions(); + + private static final double DEFAULT_ROW_BASED_COMPRESSION = 1; + private static final Map, Double> ROW_BASED_COMPRESSIONS = + initRowBasedCompressions(); + + private SparkCompressionUtil() {} + + /** + * Estimates how much the data in shuffle map files will compress once it is written to disk using + * a particular file format and codec. + */ + public static double shuffleCompressionRatio( + SparkSession spark, FileFormat outputFileFormat, String outputCodec) { + if (outputFileFormat == FileFormat.ORC || outputFileFormat == FileFormat.PARQUET) { + return columnarCompression(shuffleCodec(spark), outputCodec); + } else if (outputFileFormat == FileFormat.AVRO) { + return rowBasedCompression(shuffleCodec(spark), outputCodec); + } else { + return 1.0; + } + } + + private static String shuffleCodec(SparkSession spark) { + SparkConf sparkConf = spark.sparkContext().conf(); + return shuffleCompressionEnabled(sparkConf) ? sparkCodec(sparkConf) : NONE; + } + + private static boolean shuffleCompressionEnabled(SparkConf sparkConf) { + return sparkConf.getBoolean(SHUFFLE_COMPRESSION_ENABLED, SHUFFLE_COMPRESSION_ENABLED_DEFAULT); + } + + private static String sparkCodec(SparkConf sparkConf) { + return sparkConf.get(SPARK_COMPRESSION_CODEC, SPARK_COMPRESSION_CODEC_DEFAULT); + } + + private static double columnarCompression(String shuffleCodec, String outputCodec) { + Pair key = Pair.of(normalize(shuffleCodec), normalize(outputCodec)); + return COLUMNAR_COMPRESSIONS.getOrDefault(key, DEFAULT_COLUMNAR_COMPRESSION); + } + + private static double rowBasedCompression(String shuffleCodec, String outputCodec) { + Pair key = Pair.of(normalize(shuffleCodec), normalize(outputCodec)); + return ROW_BASED_COMPRESSIONS.getOrDefault(key, DEFAULT_ROW_BASED_COMPRESSION); + } + + private static String normalize(String value) { + return value != null ? 
value.toLowerCase(Locale.ROOT) : null; + } + + private static Map, Double> initColumnarCompressions() { + Map, Double> compressions = Maps.newHashMap(); + + compressions.put(Pair.of(NONE, ZSTD), 4.0); + compressions.put(Pair.of(NONE, GZIP), 4.0); + compressions.put(Pair.of(NONE, ZLIB), 4.0); + compressions.put(Pair.of(NONE, SNAPPY), 3.0); + compressions.put(Pair.of(NONE, LZ4), 3.0); + + compressions.put(Pair.of(ZSTD, ZSTD), 2.0); + compressions.put(Pair.of(ZSTD, GZIP), 2.0); + compressions.put(Pair.of(ZSTD, ZLIB), 2.0); + compressions.put(Pair.of(ZSTD, SNAPPY), 1.5); + compressions.put(Pair.of(ZSTD, LZ4), 1.5); + + compressions.put(Pair.of(SNAPPY, ZSTD), 3.0); + compressions.put(Pair.of(SNAPPY, GZIP), 3.0); + compressions.put(Pair.of(SNAPPY, ZLIB), 3.0); + compressions.put(Pair.of(SNAPPY, SNAPPY), 2.0); + compressions.put(Pair.of(SNAPPY, LZ4), 2.); + + compressions.put(Pair.of(LZ4, ZSTD), 3.0); + compressions.put(Pair.of(LZ4, GZIP), 3.0); + compressions.put(Pair.of(LZ4, ZLIB), 3.0); + compressions.put(Pair.of(LZ4, SNAPPY), 2.0); + compressions.put(Pair.of(LZ4, LZ4), 2.0); + + return compressions; + } + + private static Map, Double> initRowBasedCompressions() { + Map, Double> compressions = Maps.newHashMap(); + + compressions.put(Pair.of(NONE, ZSTD), 2.0); + compressions.put(Pair.of(NONE, GZIP), 2.0); + compressions.put(Pair.of(NONE, ZLIB), 2.0); + + compressions.put(Pair.of(ZSTD, SNAPPY), 0.5); + compressions.put(Pair.of(ZSTD, LZ4), 0.5); + + compressions.put(Pair.of(SNAPPY, ZSTD), 1.5); + compressions.put(Pair.of(SNAPPY, GZIP), 1.5); + compressions.put(Pair.of(SNAPPY, ZLIB), 1.5); + + compressions.put(Pair.of(LZ4, ZSTD), 1.5); + compressions.put(Pair.of(LZ4, GZIP), 1.5); + compressions.put(Pair.of(LZ4, ZLIB), 1.5); + + return compressions; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkConfParser.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkConfParser.java new file mode 100644 index 000000000000..0e9679e14d33 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkConfParser.java @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.sql.RuntimeConfig; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.RuntimeConfigImpl; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +class SparkConfParser { + + private final Map properties; + private final RuntimeConfig sessionConf; + private final CaseInsensitiveStringMap options; + + SparkConfParser() { + this.properties = ImmutableMap.of(); + this.sessionConf = new RuntimeConfigImpl(SQLConf.get()); + this.options = CaseInsensitiveStringMap.empty(); + } + + SparkConfParser(SparkSession spark, Table table, Map options) { + this.properties = table.properties(); + this.sessionConf = spark.conf(); + this.options = asCaseInsensitiveStringMap(options); + } + + public BooleanConfParser booleanConf() { + return new BooleanConfParser(); + } + + public IntConfParser intConf() { + return new IntConfParser(); + } + + public LongConfParser longConf() { + return new LongConfParser(); + } + + public StringConfParser stringConf() { + return new StringConfParser(); + } + + public DurationConfParser durationConf() { + return new DurationConfParser(); + } + + public > EnumConfParser enumConf(Function toEnum) { + return new EnumConfParser<>(toEnum); + } + + private static CaseInsensitiveStringMap asCaseInsensitiveStringMap(Map map) { + if (map instanceof CaseInsensitiveStringMap) { + return (CaseInsensitiveStringMap) map; + } else { + return new CaseInsensitiveStringMap(map); + } + } + + class BooleanConfParser extends ConfParser { + private Boolean defaultValue; + private boolean negate = false; + + @Override + protected BooleanConfParser self() { + return this; + } + + public BooleanConfParser defaultValue(boolean value) { + this.defaultValue = value; + return self(); + } + + public BooleanConfParser defaultValue(String value) { + this.defaultValue = Boolean.parseBoolean(value); + return self(); + } + + public BooleanConfParser negate() { + this.negate = true; + return self(); + } + + public boolean parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + boolean value = parse(Boolean::parseBoolean, defaultValue); + return negate ? 
!value : value; + } + } + + class IntConfParser extends ConfParser { + private Integer defaultValue; + + @Override + protected IntConfParser self() { + return this; + } + + public IntConfParser defaultValue(int value) { + this.defaultValue = value; + return self(); + } + + public int parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Integer::parseInt, defaultValue); + } + + public Integer parseOptional() { + return parse(Integer::parseInt, defaultValue); + } + } + + class LongConfParser extends ConfParser { + private Long defaultValue; + + @Override + protected LongConfParser self() { + return this; + } + + public LongConfParser defaultValue(long value) { + this.defaultValue = value; + return self(); + } + + public long parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Long::parseLong, defaultValue); + } + + public Long parseOptional() { + return parse(Long::parseLong, defaultValue); + } + } + + class StringConfParser extends ConfParser { + private String defaultValue; + + @Override + protected StringConfParser self() { + return this; + } + + public StringConfParser defaultValue(String value) { + this.defaultValue = value; + return self(); + } + + public String parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Function.identity(), defaultValue); + } + + public String parseOptional() { + return parse(Function.identity(), defaultValue); + } + } + + class DurationConfParser extends ConfParser { + private Duration defaultValue; + + @Override + protected DurationConfParser self() { + return this; + } + + public DurationConfParser defaultValue(Duration value) { + this.defaultValue = value; + return self(); + } + + public Duration parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(this::toDuration, defaultValue); + } + + public Duration parseOptional() { + return parse(this::toDuration, defaultValue); + } + + private Duration toDuration(String time) { + return Duration.ofSeconds(JavaUtils.timeStringAsSec(time)); + } + } + + class EnumConfParser> extends ConfParser, T> { + private final Function toEnum; + private T defaultValue; + + EnumConfParser(Function toEnum) { + this.toEnum = toEnum; + } + + @Override + protected EnumConfParser self() { + return this; + } + + public EnumConfParser defaultValue(T value) { + this.defaultValue = value; + return self(); + } + + public EnumConfParser defaultValue(String value) { + this.defaultValue = toEnum.apply(value); + return self(); + } + + public T parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(toEnum, defaultValue); + } + + public T parseOptional() { + return parse(toEnum, defaultValue); + } + } + + abstract class ConfParser { + private final List optionNames = Lists.newArrayList(); + private String sessionConfName; + private String tablePropertyName; + + protected abstract ThisT self(); + + public ThisT option(String name) { + this.optionNames.add(name); + return self(); + } + + public ThisT sessionConf(String name) { + this.sessionConfName = name; + return self(); + } + + public ThisT tableProperty(String name) { + this.tablePropertyName = name; + return self(); + } + + protected T parse(Function conversion, T defaultValue) { + for (String optionName : optionNames) { + String optionValue = options.get(optionName); + if (optionValue != null) { + return 
conversion.apply(optionValue); + } + + String sparkOptionValue = options.get(toCamelCase(optionName)); + if (sparkOptionValue != null) { + return conversion.apply(sparkOptionValue); + } + } + + if (sessionConfName != null) { + String sessionConfValue = sessionConf.get(sessionConfName, null); + if (sessionConfValue != null) { + return conversion.apply(sessionConfValue); + } + + String sparkSessionConfValue = sessionConf.get(toCamelCase(sessionConfName), null); + if (sparkSessionConfValue != null) { + return conversion.apply(sparkSessionConfValue); + } + } + + if (tablePropertyName != null) { + String propertyValue = properties.get(tablePropertyName); + if (propertyValue != null) { + return conversion.apply(propertyValue); + } + } + + return defaultValue; + } + + private String toCamelCase(String key) { + StringBuilder transformedKey = new StringBuilder(); + boolean capitalizeNext = false; + + for (char character : key.toCharArray()) { + if (character == '-') { + capitalizeNext = true; + } else if (capitalizeNext) { + transformedKey.append(Character.toUpperCase(character)); + capitalizeNext = false; + } else { + transformedKey.append(character); + } + } + + return transformedKey.toString(); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkContentFile.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkContentFile.java new file mode 100644 index 000000000000..bad31d8d85f4 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkContentFile.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileContent; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.StructProjection; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructType; + +public abstract class SparkContentFile implements ContentFile { + + private static final FileContent[] FILE_CONTENT_VALUES = FileContent.values(); + + private final int fileContentPosition; + private final int filePathPosition; + private final int fileFormatPosition; + private final int partitionPosition; + private final int recordCountPosition; + private final int fileSizeInBytesPosition; + private final int columnSizesPosition; + private final int valueCountsPosition; + private final int nullValueCountsPosition; + private final int nanValueCountsPosition; + private final int lowerBoundsPosition; + private final int upperBoundsPosition; + private final int keyMetadataPosition; + private final int splitOffsetsPosition; + private final int sortOrderIdPosition; + private final int fileSpecIdPosition; + private final int equalityIdsPosition; + private final int referencedDataFilePosition; + private final int contentOffsetPosition; + private final int contentSizePosition; + private final Type lowerBoundsType; + private final Type upperBoundsType; + private final Type keyMetadataType; + + private final SparkStructLike wrappedPartition; + private final StructLike projectedPartition; + private Row wrapped; + + SparkContentFile(Types.StructType type, Types.StructType projectedType, StructType sparkType) { + this.lowerBoundsType = type.fieldType(DataFile.LOWER_BOUNDS.name()); + this.upperBoundsType = type.fieldType(DataFile.UPPER_BOUNDS.name()); + this.keyMetadataType = type.fieldType(DataFile.KEY_METADATA.name()); + + Types.StructType partitionType = type.fieldType(DataFile.PARTITION_NAME).asStructType(); + this.wrappedPartition = new SparkStructLike(partitionType); + + if (projectedType != null) { + Types.StructType projectedPartitionType = + projectedType.fieldType(DataFile.PARTITION_NAME).asStructType(); + StructProjection partitionProjection = + StructProjection.create(partitionType, projectedPartitionType); + this.projectedPartition = partitionProjection.wrap(wrappedPartition); + } else { + this.projectedPartition = wrappedPartition; + } + + Map positions = Maps.newHashMap(); + for (Types.NestedField field : type.fields()) { + String fieldName = field.name(); + positions.put(fieldName, fieldPosition(fieldName, sparkType)); + } + + this.fileContentPosition = positions.get(DataFile.CONTENT.name()); + this.filePathPosition = positions.get(DataFile.FILE_PATH.name()); + this.fileFormatPosition = positions.get(DataFile.FILE_FORMAT.name()); + this.partitionPosition = positions.get(DataFile.PARTITION_NAME); + this.recordCountPosition = positions.get(DataFile.RECORD_COUNT.name()); + this.fileSizeInBytesPosition = positions.get(DataFile.FILE_SIZE.name()); + this.columnSizesPosition = positions.get(DataFile.COLUMN_SIZES.name()); + this.valueCountsPosition = positions.get(DataFile.VALUE_COUNTS.name()); + this.nullValueCountsPosition = positions.get(DataFile.NULL_VALUE_COUNTS.name()); + this.nanValueCountsPosition = 
positions.get(DataFile.NAN_VALUE_COUNTS.name()); + this.lowerBoundsPosition = positions.get(DataFile.LOWER_BOUNDS.name()); + this.upperBoundsPosition = positions.get(DataFile.UPPER_BOUNDS.name()); + this.keyMetadataPosition = positions.get(DataFile.KEY_METADATA.name()); + this.splitOffsetsPosition = positions.get(DataFile.SPLIT_OFFSETS.name()); + this.sortOrderIdPosition = positions.get(DataFile.SORT_ORDER_ID.name()); + this.fileSpecIdPosition = positions.get(DataFile.SPEC_ID.name()); + this.equalityIdsPosition = positions.get(DataFile.EQUALITY_IDS.name()); + this.referencedDataFilePosition = positions.get(DataFile.REFERENCED_DATA_FILE.name()); + this.contentOffsetPosition = positions.get(DataFile.CONTENT_OFFSET.name()); + this.contentSizePosition = positions.get(DataFile.CONTENT_SIZE.name()); + } + + public F wrap(Row row) { + this.wrapped = row; + if (wrappedPartition.size() > 0) { + wrappedPartition.wrap(row.getAs(partitionPosition)); + } + return asFile(); + } + + protected abstract F asFile(); + + @Override + public Long pos() { + return null; + } + + @Override + public int specId() { + if (wrapped.isNullAt(fileSpecIdPosition)) { + return -1; + } + return wrapped.getAs(fileSpecIdPosition); + } + + @Override + public FileContent content() { + if (wrapped.isNullAt(fileContentPosition)) { + return null; + } + return FILE_CONTENT_VALUES[wrapped.getInt(fileContentPosition)]; + } + + @Override + public CharSequence path() { + return wrapped.getAs(filePathPosition); + } + + @Override + public FileFormat format() { + return FileFormat.fromString(wrapped.getString(fileFormatPosition)); + } + + @Override + public StructLike partition() { + return projectedPartition; + } + + @Override + public long recordCount() { + return wrapped.getAs(recordCountPosition); + } + + @Override + public long fileSizeInBytes() { + return wrapped.getAs(fileSizeInBytesPosition); + } + + @Override + public Map columnSizes() { + return wrapped.isNullAt(columnSizesPosition) ? null : wrapped.getJavaMap(columnSizesPosition); + } + + @Override + public Map valueCounts() { + return wrapped.isNullAt(valueCountsPosition) ? null : wrapped.getJavaMap(valueCountsPosition); + } + + @Override + public Map nullValueCounts() { + if (wrapped.isNullAt(nullValueCountsPosition)) { + return null; + } + return wrapped.getJavaMap(nullValueCountsPosition); + } + + @Override + public Map nanValueCounts() { + if (wrapped.isNullAt(nanValueCountsPosition)) { + return null; + } + return wrapped.getJavaMap(nanValueCountsPosition); + } + + @Override + public Map lowerBounds() { + Map lowerBounds = + wrapped.isNullAt(lowerBoundsPosition) ? null : wrapped.getJavaMap(lowerBoundsPosition); + return convert(lowerBoundsType, lowerBounds); + } + + @Override + public Map upperBounds() { + Map upperBounds = + wrapped.isNullAt(upperBoundsPosition) ? null : wrapped.getJavaMap(upperBoundsPosition); + return convert(upperBoundsType, upperBounds); + } + + @Override + public ByteBuffer keyMetadata() { + return convert(keyMetadataType, wrapped.get(keyMetadataPosition)); + } + + @Override + public F copy() { + throw new UnsupportedOperationException("Not implemented: copy"); + } + + @Override + public F copyWithoutStats() { + throw new UnsupportedOperationException("Not implemented: copyWithoutStats"); + } + + @Override + public List splitOffsets() { + return wrapped.isNullAt(splitOffsetsPosition) ? 
null : wrapped.getList(splitOffsetsPosition); + } + + @Override + public Integer sortOrderId() { + return wrapped.getAs(sortOrderIdPosition); + } + + @Override + public List equalityFieldIds() { + return wrapped.isNullAt(equalityIdsPosition) ? null : wrapped.getList(equalityIdsPosition); + } + + public String referencedDataFile() { + if (wrapped.isNullAt(referencedDataFilePosition)) { + return null; + } + return wrapped.getString(referencedDataFilePosition); + } + + public Long contentOffset() { + if (wrapped.isNullAt(contentOffsetPosition)) { + return null; + } + return wrapped.getLong(contentOffsetPosition); + } + + public Long contentSizeInBytes() { + if (wrapped.isNullAt(contentSizePosition)) { + return null; + } + return wrapped.getLong(contentSizePosition); + } + + private int fieldPosition(String name, StructType sparkType) { + try { + return sparkType.fieldIndex(name); + } catch (IllegalArgumentException e) { + // the partition field is absent for unpartitioned tables + if (name.equals(DataFile.PARTITION_NAME) && wrappedPartition.size() == 0) { + return -1; + } + throw e; + } + } + + @SuppressWarnings("unchecked") + private T convert(Type valueType, Object value) { + return (T) SparkValueConverter.convert(valueType, value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkDataFile.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkDataFile.java new file mode 100644 index 000000000000..543ebf3f9ea7 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkDataFile.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.StructType; + +public class SparkDataFile extends SparkContentFile implements DataFile { + + public SparkDataFile(Types.StructType type, StructType sparkType) { + super(type, null, sparkType); + } + + public SparkDataFile( + Types.StructType type, Types.StructType projectedType, StructType sparkType) { + super(type, projectedType, sparkType); + } + + @Override + protected DataFile asFile() { + return this; + } + + @Override + public List equalityFieldIds() { + return null; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkDeleteFile.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkDeleteFile.java new file mode 100644 index 000000000000..6250a1630683 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkDeleteFile.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
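// A minimal sketch of consuming DataFile-shaped rows (e.g. from a "files" metadata table) through
// the wrappers above, assuming `table` is an Iceberg Table and `filesDf` is a Dataset<Row>. The
// wrapper is reused across rows, so each wrapped file is read before the next call to wrap().
Types.StructType dataFileType = DataFile.getType(table.spec().partitionType());
SparkDataFile wrapper = new SparkDataFile(dataFileType, filesDf.schema());
for (Row row : filesDf.collectAsList()) {
  DataFile file = wrapper.wrap(row);
  System.out.printf("%s: %d records, %d bytes%n", file.path(), file.recordCount(), file.fileSizeInBytes());
}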
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.StructType; + +public class SparkDeleteFile extends SparkContentFile implements DeleteFile { + + public SparkDeleteFile(Types.StructType type, StructType sparkType) { + super(type, null, sparkType); + } + + public SparkDeleteFile( + Types.StructType type, Types.StructType projectedType, StructType sparkType) { + super(type, projectedType, sparkType); + } + + @Override + protected DeleteFile asFile() { + return this; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkExceptionUtil.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkExceptionUtil.java new file mode 100644 index 000000000000..5c6fe3e0ff96 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkExceptionUtil.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import com.google.errorprone.annotations.FormatMethod; +import java.io.IOException; +import org.apache.iceberg.exceptions.NoSuchNamespaceException; +import org.apache.iceberg.exceptions.NoSuchTableException; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.spark.sql.AnalysisException; + +public class SparkExceptionUtil { + + private SparkExceptionUtil() {} + + /** + * Converts checked exceptions to unchecked exceptions. + * + * @param cause a checked exception object which is to be converted to its unchecked equivalent. + * @param message exception message as a format string + * @param args format specifiers + * @return unchecked exception. + */ + @FormatMethod + public static RuntimeException toUncheckedException( + final Throwable cause, final String message, final Object... 
args) { + // Parameters are required to be final to help @FormatMethod do static analysis + if (cause instanceof RuntimeException) { + return (RuntimeException) cause; + + } else if (cause instanceof org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException) { + return new NoSuchNamespaceException(cause, message, args); + + } else if (cause instanceof org.apache.spark.sql.catalyst.analysis.NoSuchTableException) { + return new NoSuchTableException(cause, message, args); + + } else if (cause instanceof AnalysisException) { + return new ValidationException(cause, message, args); + + } else if (cause instanceof IOException) { + return new RuntimeIOException((IOException) cause, message, args); + + } else { + return new RuntimeException(String.format(message, args), cause); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkExecutorCache.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkExecutorCache.java new file mode 100644 index 000000000000..6490d6678b46 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkExecutorCache.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import java.time.Duration; +import java.util.List; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An executor cache for reducing the computation and IO overhead in tasks. + * + *
<p>
The cache is configured and controlled through Spark SQL properties. It supports both limits + * on the total cache size and maximum size for individual entries. Additionally, it implements + * automatic eviction of entries after a specified duration of inactivity. The cache will respect + * the SQL configuration valid at the time of initialization. All subsequent changes to the + * configuration will have no effect. + * + *
<p>
The cache is accessed and populated via {@link #getOrLoad(String, String, Supplier, long)}. If + * the value is not present in the cache, it is computed using the provided supplier and stored in + * the cache, subject to the defined size constraints. When a key is added, it must be associated + * with a particular group ID. Once the group is no longer needed, it is recommended to explicitly + * invalidate its state by calling {@link #invalidate(String)} instead of relying on automatic + * eviction. + * + *
<p>
Note that this class employs the singleton pattern to ensure only one cache exists per JVM. + */ +public class SparkExecutorCache { + + private static final Logger LOG = LoggerFactory.getLogger(SparkExecutorCache.class); + + private static volatile SparkExecutorCache instance = null; + + private final Duration timeout; + private final long maxEntrySize; + private final long maxTotalSize; + private volatile Cache state; + + private SparkExecutorCache(Conf conf) { + this.timeout = conf.timeout(); + this.maxEntrySize = conf.maxEntrySize(); + this.maxTotalSize = conf.maxTotalSize(); + } + + /** + * Returns the cache if created or creates and returns it. + * + *
<p>
Note this method returns null if caching is disabled. + */ + public static SparkExecutorCache getOrCreate() { + if (instance == null) { + Conf conf = new Conf(); + if (conf.cacheEnabled()) { + synchronized (SparkExecutorCache.class) { + if (instance == null) { + SparkExecutorCache.instance = new SparkExecutorCache(conf); + } + } + } + } + + return instance; + } + + /** Returns the cache if already created or null otherwise. */ + public static SparkExecutorCache get() { + return instance; + } + + /** Returns the max entry size in bytes that will be considered for caching. */ + public long maxEntrySize() { + return maxEntrySize; + } + + /** + * Gets the cached value for the key or populates the cache with a new mapping. + * + * @param group a group ID + * @param key a cache key + * @param valueSupplier a supplier to compute the value + * @param valueSize an estimated memory size of the value in bytes + * @return the cached or computed value + */ + public V getOrLoad(String group, String key, Supplier valueSupplier, long valueSize) { + if (valueSize > maxEntrySize) { + LOG.debug("{} exceeds max entry size: {} > {}", key, valueSize, maxEntrySize); + return valueSupplier.get(); + } + + String internalKey = group + "_" + key; + CacheValue value = state().get(internalKey, loadFunc(valueSupplier, valueSize)); + Preconditions.checkNotNull(value, "Loaded value must not be null"); + return value.get(); + } + + private Function loadFunc(Supplier valueSupplier, long valueSize) { + return key -> { + long start = System.currentTimeMillis(); + V value = valueSupplier.get(); + long end = System.currentTimeMillis(); + LOG.debug("Loaded {} with size {} in {} ms", key, valueSize, (end - start)); + return new CacheValue(value, valueSize); + }; + } + + /** + * Invalidates all keys associated with the given group ID. 
+ * + * @param group a group ID + */ + public void invalidate(String group) { + if (state != null) { + List internalKeys = findInternalKeys(group); + LOG.info("Invalidating {} keys associated with {}", internalKeys.size(), group); + internalKeys.forEach(internalKey -> state.invalidate(internalKey)); + LOG.info("Current cache stats {}", state.stats()); + } + } + + private List findInternalKeys(String group) { + return state.asMap().keySet().stream() + .filter(internalKey -> internalKey.startsWith(group)) + .collect(Collectors.toList()); + } + + private Cache state() { + if (state == null) { + synchronized (this) { + if (state == null) { + LOG.info("Initializing cache state"); + this.state = initState(); + } + } + } + + return state; + } + + private Cache initState() { + return Caffeine.newBuilder() + .expireAfterAccess(timeout) + .maximumWeight(maxTotalSize) + .weigher((key, value) -> ((CacheValue) value).weight()) + .recordStats() + .removalListener((key, value, cause) -> LOG.debug("Evicted {} ({})", key, cause)) + .build(); + } + + @VisibleForTesting + static class CacheValue { + private final Object value; + private final long size; + + CacheValue(Object value, long size) { + this.value = value; + this.size = size; + } + + @SuppressWarnings("unchecked") + public V get() { + return (V) value; + } + + public int weight() { + return (int) Math.min(size, Integer.MAX_VALUE); + } + } + + @VisibleForTesting + static class Conf { + private final SparkConfParser confParser = new SparkConfParser(); + + public boolean cacheEnabled() { + return confParser + .booleanConf() + .sessionConf(SparkSQLProperties.EXECUTOR_CACHE_ENABLED) + .defaultValue(SparkSQLProperties.EXECUTOR_CACHE_ENABLED_DEFAULT) + .parse(); + } + + public Duration timeout() { + return confParser + .durationConf() + .sessionConf(SparkSQLProperties.EXECUTOR_CACHE_TIMEOUT) + .defaultValue(SparkSQLProperties.EXECUTOR_CACHE_TIMEOUT_DEFAULT) + .parse(); + } + + public long maxEntrySize() { + return confParser + .longConf() + .sessionConf(SparkSQLProperties.EXECUTOR_CACHE_MAX_ENTRY_SIZE) + .defaultValue(SparkSQLProperties.EXECUTOR_CACHE_MAX_ENTRY_SIZE_DEFAULT) + .parse(); + } + + public long maxTotalSize() { + return confParser + .longConf() + .sessionConf(SparkSQLProperties.EXECUTOR_CACHE_MAX_TOTAL_SIZE) + .defaultValue(SparkSQLProperties.EXECUTOR_CACHE_MAX_TOTAL_SIZE_DEFAULT) + .parse(); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkFilters.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkFilters.java new file mode 100644 index 000000000000..49b73c7b01af --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkFilters.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
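// A minimal sketch of the executor cache contract described above; the group id, key, loader
// (loadDeleteFile) and size estimate are illustrative placeholders.
SparkExecutorCache cache = SparkExecutorCache.getOrCreate();
byte[] deleteContent;
if (cache != null) { // getOrCreate() returns null when the executor cache is disabled
  deleteContent = cache.getOrLoad("table-uuid-1", "delete-file-A.parquet", () -> loadDeleteFile(), 4096L);
} else {
  deleteContent = loadDeleteFile();
}
// once the group is no longer needed, release its entries eagerly instead of waiting for eviction
SparkExecutorCache liveCache = SparkExecutorCache.get();
if (liveCache != null) {
  liveCache.invalidate("table-uuid-1");
}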
+ */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.expressions.Expressions.and; +import static org.apache.iceberg.expressions.Expressions.equal; +import static org.apache.iceberg.expressions.Expressions.greaterThan; +import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; +import static org.apache.iceberg.expressions.Expressions.isNull; +import static org.apache.iceberg.expressions.Expressions.lessThan; +import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.not; +import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNull; +import static org.apache.iceberg.expressions.Expressions.or; +import static org.apache.iceberg.expressions.Expressions.startsWith; + +import java.sql.Date; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.Map; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expression.Operation; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.util.NaNUtil; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.sources.AlwaysFalse; +import org.apache.spark.sql.sources.AlwaysFalse$; +import org.apache.spark.sql.sources.AlwaysTrue; +import org.apache.spark.sql.sources.AlwaysTrue$; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualNullSafe; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.In; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; +import org.apache.spark.sql.sources.StringStartsWith; + +public class SparkFilters { + + private static final Pattern BACKTICKS_PATTERN = Pattern.compile("([`])(.|$)"); + + private SparkFilters() {} + + private static final Map, Operation> FILTERS = + ImmutableMap., Operation>builder() + .put(AlwaysTrue.class, Operation.TRUE) + .put(AlwaysTrue$.class, Operation.TRUE) + .put(AlwaysFalse$.class, Operation.FALSE) + .put(AlwaysFalse.class, Operation.FALSE) + .put(EqualTo.class, Operation.EQ) + .put(EqualNullSafe.class, Operation.EQ) + .put(GreaterThan.class, Operation.GT) + .put(GreaterThanOrEqual.class, Operation.GT_EQ) + .put(LessThan.class, Operation.LT) + .put(LessThanOrEqual.class, Operation.LT_EQ) + .put(In.class, Operation.IN) + .put(IsNull.class, Operation.IS_NULL) + .put(IsNotNull.class, Operation.NOT_NULL) + .put(And.class, Operation.AND) + .put(Or.class, Operation.OR) + .put(Not.class, Operation.NOT) + .put(StringStartsWith.class, Operation.STARTS_WITH) + .buildOrThrow(); + + public static 
Expression convert(Filter[] filters) { + Expression expression = Expressions.alwaysTrue(); + for (Filter filter : filters) { + Expression converted = convert(filter); + Preconditions.checkArgument( + converted != null, "Cannot convert filter to Iceberg: %s", filter); + expression = Expressions.and(expression, converted); + } + return expression; + } + + public static Expression convert(Filter filter) { + // avoid using a chain of if instanceof statements by mapping to the expression enum. + Operation op = FILTERS.get(filter.getClass()); + if (op != null) { + switch (op) { + case TRUE: + return Expressions.alwaysTrue(); + + case FALSE: + return Expressions.alwaysFalse(); + + case IS_NULL: + IsNull isNullFilter = (IsNull) filter; + return isNull(unquote(isNullFilter.attribute())); + + case NOT_NULL: + IsNotNull notNullFilter = (IsNotNull) filter; + return notNull(unquote(notNullFilter.attribute())); + + case LT: + LessThan lt = (LessThan) filter; + return lessThan(unquote(lt.attribute()), convertLiteral(lt.value())); + + case LT_EQ: + LessThanOrEqual ltEq = (LessThanOrEqual) filter; + return lessThanOrEqual(unquote(ltEq.attribute()), convertLiteral(ltEq.value())); + + case GT: + GreaterThan gt = (GreaterThan) filter; + return greaterThan(unquote(gt.attribute()), convertLiteral(gt.value())); + + case GT_EQ: + GreaterThanOrEqual gtEq = (GreaterThanOrEqual) filter; + return greaterThanOrEqual(unquote(gtEq.attribute()), convertLiteral(gtEq.value())); + + case EQ: // used for both eq and null-safe-eq + if (filter instanceof EqualTo) { + EqualTo eq = (EqualTo) filter; + // comparison with null in normal equality is always null. this is probably a mistake. + Preconditions.checkNotNull( + eq.value(), "Expression is always false (eq is not null-safe): %s", filter); + return handleEqual(unquote(eq.attribute()), eq.value()); + } else { + EqualNullSafe eq = (EqualNullSafe) filter; + if (eq.value() == null) { + return isNull(unquote(eq.attribute())); + } else { + return handleEqual(unquote(eq.attribute()), eq.value()); + } + } + + case IN: + In inFilter = (In) filter; + return in( + unquote(inFilter.attribute()), + Stream.of(inFilter.values()) + .filter(Objects::nonNull) + .map(SparkFilters::convertLiteral) + .collect(Collectors.toList())); + + case NOT: + Not notFilter = (Not) filter; + Filter childFilter = notFilter.child(); + Operation childOp = FILTERS.get(childFilter.getClass()); + if (childOp == Operation.IN) { + // infer an extra notNull predicate for Spark NOT IN filters + // as Iceberg expressions don't follow the 3-value SQL boolean logic + // col NOT IN (1, 2) in Spark is equivalent to notNull(col) && notIn(col, 1, 2) in + // Iceberg + In childInFilter = (In) childFilter; + Expression notIn = + notIn( + unquote(childInFilter.attribute()), + Stream.of(childInFilter.values()) + .map(SparkFilters::convertLiteral) + .collect(Collectors.toList())); + return and(notNull(childInFilter.attribute()), notIn); + } else if (hasNoInFilter(childFilter)) { + Expression child = convert(childFilter); + if (child != null) { + return not(child); + } + } + return null; + + case AND: + { + And andFilter = (And) filter; + Expression left = convert(andFilter.left()); + Expression right = convert(andFilter.right()); + if (left != null && right != null) { + return and(left, right); + } + return null; + } + + case OR: + { + Or orFilter = (Or) filter; + Expression left = convert(orFilter.left()); + Expression right = convert(orFilter.right()); + if (left != null && right != null) { + return or(left, right); + } + return 
null; + } + + case STARTS_WITH: + { + StringStartsWith stringStartsWith = (StringStartsWith) filter; + return startsWith(unquote(stringStartsWith.attribute()), stringStartsWith.value()); + } + } + } + + return null; + } + + private static Object convertLiteral(Object value) { + if (value instanceof Timestamp) { + return DateTimeUtils.fromJavaTimestamp((Timestamp) value); + } else if (value instanceof Date) { + return DateTimeUtils.fromJavaDate((Date) value); + } else if (value instanceof Instant) { + return DateTimeUtils.instantToMicros((Instant) value); + } else if (value instanceof LocalDateTime) { + return DateTimeUtils.localDateTimeToMicros((LocalDateTime) value); + } else if (value instanceof LocalDate) { + return DateTimeUtils.localDateToDays((LocalDate) value); + } + return value; + } + + private static Expression handleEqual(String attribute, Object value) { + if (NaNUtil.isNaN(value)) { + return isNaN(attribute); + } else { + return equal(attribute, convertLiteral(value)); + } + } + + private static String unquote(String attributeName) { + Matcher matcher = BACKTICKS_PATTERN.matcher(attributeName); + return matcher.replaceAll("$2"); + } + + private static boolean hasNoInFilter(Filter filter) { + Operation op = FILTERS.get(filter.getClass()); + + if (op != null) { + switch (op) { + case AND: + And andFilter = (And) filter; + return hasNoInFilter(andFilter.left()) && hasNoInFilter(andFilter.right()); + case OR: + Or orFilter = (Or) filter; + return hasNoInFilter(orFilter.left()) && hasNoInFilter(orFilter.right()); + case NOT: + Not notFilter = (Not) filter; + return hasNoInFilter(notFilter.child()); + case IN: + return false; + default: + return true; + } + } + + return false; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTypes.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTypes.java new file mode 100644 index 000000000000..6c4ec39b20f1 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTypes.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.types.FixupTypes; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; + +/** + * Some types, like binary and fixed, are converted to the same Spark type. Conversion back can + * produce only one, which may not be correct. 
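// A minimal sketch of the filter conversion above, using an illustrative `id` column.
Filter[] pushed = {new GreaterThan("id", 5), new Not(new In("id", new Object[] {7, 8}))};
Expression converted = SparkFilters.convert(pushed);
// converted is equivalent to: id > 5 AND id IS NOT NULL AND id NOT IN (7, 8);
// the extra notNull comes from the NOT IN handling in convert(Filter)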
+ */ +class SparkFixupTypes extends FixupTypes { + + private SparkFixupTypes(Schema referenceSchema) { + super(referenceSchema); + } + + static Schema fixup(Schema schema, Schema referenceSchema) { + return new Schema( + TypeUtil.visit(schema, new SparkFixupTypes(referenceSchema)).asStructType().fields()); + } + + @Override + protected boolean fixupPrimitive(Type.PrimitiveType type, Type source) { + switch (type.typeId()) { + case STRING: + if (source.typeId() == Type.TypeID.UUID) { + return true; + } + break; + case BINARY: + if (source.typeId() == Type.TypeID.FIXED) { + return true; + } + break; + case TIMESTAMP: + if (source.typeId() == Type.TypeID.TIMESTAMP) { + return true; + } + break; + default: + } + return false; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkFunctionCatalog.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkFunctionCatalog.java new file mode 100644 index 000000000000..2183b9e5df4d --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkFunctionCatalog.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * A function catalog that can be used to resolve Iceberg functions without a metastore connection. + */ +public class SparkFunctionCatalog implements SupportsFunctions { + + private static final SparkFunctionCatalog INSTANCE = new SparkFunctionCatalog(); + + private String name = "iceberg-function-catalog"; + + public static SparkFunctionCatalog get() { + return INSTANCE; + } + + @Override + public void initialize(String catalogName, CaseInsensitiveStringMap options) { + this.name = catalogName; + } + + @Override + public String name() { + return name; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java new file mode 100644 index 000000000000..67e9d78ada4d --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java @@ -0,0 +1,358 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.PlanningMode.LOCAL; + +import java.util.Map; +import org.apache.iceberg.PlanningMode; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.Util; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.SparkConf; +import org.apache.spark.sql.SparkSession; + +/** + * A class for common Iceberg configs for Spark reads. + * + *
<p>
If a config is set at multiple levels, the following order of precedence is used (top to + * bottom): + * + *
<ol>
+ *   <li>Read options
+ *   <li>Session configuration
+ *   <li>Table metadata
+ * </ol>
+ *
+ * <p>
Note this class is NOT meant to be serialized and sent to executors. + */ +public class SparkReadConf { + + private static final String DRIVER_MAX_RESULT_SIZE = "spark.driver.maxResultSize"; + private static final String DRIVER_MAX_RESULT_SIZE_DEFAULT = "1G"; + private static final long DISTRIBUTED_PLANNING_MIN_RESULT_SIZE = 256L * 1024 * 1024; // 256 MB + + private final SparkSession spark; + private final Table table; + private final String branch; + private final SparkConfParser confParser; + + public SparkReadConf(SparkSession spark, Table table, Map readOptions) { + this(spark, table, null, readOptions); + } + + public SparkReadConf( + SparkSession spark, Table table, String branch, Map readOptions) { + this.spark = spark; + this.table = table; + this.branch = branch; + this.confParser = new SparkConfParser(spark, table, readOptions); + } + + public boolean caseSensitive() { + return SparkUtil.caseSensitive(spark); + } + + public boolean localityEnabled() { + boolean defaultValue = Util.mayHaveBlockLocations(table.io(), table.location()); + return confParser + .booleanConf() + .option(SparkReadOptions.LOCALITY) + .sessionConf(SparkSQLProperties.LOCALITY) + .defaultValue(defaultValue) + .parse(); + } + + public Long snapshotId() { + return confParser.longConf().option(SparkReadOptions.SNAPSHOT_ID).parseOptional(); + } + + public Long asOfTimestamp() { + return confParser.longConf().option(SparkReadOptions.AS_OF_TIMESTAMP).parseOptional(); + } + + public Long startSnapshotId() { + return confParser.longConf().option(SparkReadOptions.START_SNAPSHOT_ID).parseOptional(); + } + + public Long endSnapshotId() { + return confParser.longConf().option(SparkReadOptions.END_SNAPSHOT_ID).parseOptional(); + } + + public String branch() { + String optionBranch = confParser.stringConf().option(SparkReadOptions.BRANCH).parseOptional(); + ValidationException.check( + branch == null || optionBranch == null || optionBranch.equals(branch), + "Must not specify different branches in both table identifier and read option, " + + "got [%s] in identifier and [%s] in options", + branch, + optionBranch); + String inputBranch = branch != null ? 
branch : optionBranch; + if (inputBranch != null) { + return inputBranch; + } + + boolean wapEnabled = + PropertyUtil.propertyAsBoolean( + table.properties(), TableProperties.WRITE_AUDIT_PUBLISH_ENABLED, false); + if (wapEnabled) { + String wapBranch = spark.conf().get(SparkSQLProperties.WAP_BRANCH, null); + if (wapBranch != null && table.refs().containsKey(wapBranch)) { + return wapBranch; + } + } + + return null; + } + + public String tag() { + return confParser.stringConf().option(SparkReadOptions.TAG).parseOptional(); + } + + public String scanTaskSetId() { + return confParser.stringConf().option(SparkReadOptions.SCAN_TASK_SET_ID).parseOptional(); + } + + public boolean streamingSkipDeleteSnapshots() { + return confParser + .booleanConf() + .option(SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS) + .defaultValue(SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS_DEFAULT) + .parse(); + } + + public boolean streamingSkipOverwriteSnapshots() { + return confParser + .booleanConf() + .option(SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS) + .defaultValue(SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS_DEFAULT) + .parse(); + } + + public boolean parquetVectorizationEnabled() { + return confParser + .booleanConf() + .option(SparkReadOptions.VECTORIZATION_ENABLED) + .sessionConf(SparkSQLProperties.VECTORIZATION_ENABLED) + .tableProperty(TableProperties.PARQUET_VECTORIZATION_ENABLED) + .defaultValue(TableProperties.PARQUET_VECTORIZATION_ENABLED_DEFAULT) + .parse(); + } + + public int parquetBatchSize() { + return confParser + .intConf() + .option(SparkReadOptions.VECTORIZATION_BATCH_SIZE) + .tableProperty(TableProperties.PARQUET_BATCH_SIZE) + .defaultValue(TableProperties.PARQUET_BATCH_SIZE_DEFAULT) + .parse(); + } + + public boolean orcVectorizationEnabled() { + return confParser + .booleanConf() + .option(SparkReadOptions.VECTORIZATION_ENABLED) + .sessionConf(SparkSQLProperties.VECTORIZATION_ENABLED) + .tableProperty(TableProperties.ORC_VECTORIZATION_ENABLED) + .defaultValue(TableProperties.ORC_VECTORIZATION_ENABLED_DEFAULT) + .parse(); + } + + public int orcBatchSize() { + return confParser + .intConf() + .option(SparkReadOptions.VECTORIZATION_BATCH_SIZE) + .tableProperty(TableProperties.ORC_BATCH_SIZE) + .defaultValue(TableProperties.ORC_BATCH_SIZE_DEFAULT) + .parse(); + } + + public Long splitSizeOption() { + return confParser.longConf().option(SparkReadOptions.SPLIT_SIZE).parseOptional(); + } + + public long splitSize() { + return confParser + .longConf() + .option(SparkReadOptions.SPLIT_SIZE) + .tableProperty(TableProperties.SPLIT_SIZE) + .defaultValue(TableProperties.SPLIT_SIZE_DEFAULT) + .parse(); + } + + public Integer splitLookbackOption() { + return confParser.intConf().option(SparkReadOptions.LOOKBACK).parseOptional(); + } + + public int splitLookback() { + return confParser + .intConf() + .option(SparkReadOptions.LOOKBACK) + .tableProperty(TableProperties.SPLIT_LOOKBACK) + .defaultValue(TableProperties.SPLIT_LOOKBACK_DEFAULT) + .parse(); + } + + public Long splitOpenFileCostOption() { + return confParser.longConf().option(SparkReadOptions.FILE_OPEN_COST).parseOptional(); + } + + public long splitOpenFileCost() { + return confParser + .longConf() + .option(SparkReadOptions.FILE_OPEN_COST) + .tableProperty(TableProperties.SPLIT_OPEN_FILE_COST) + .defaultValue(TableProperties.SPLIT_OPEN_FILE_COST_DEFAULT) + .parse(); + } + + public long streamFromTimestamp() { + return confParser + .longConf() + .option(SparkReadOptions.STREAM_FROM_TIMESTAMP) + .defaultValue(Long.MIN_VALUE) + 
.parse(); + } + + public Long startTimestamp() { + return confParser.longConf().option(SparkReadOptions.START_TIMESTAMP).parseOptional(); + } + + public Long endTimestamp() { + return confParser.longConf().option(SparkReadOptions.END_TIMESTAMP).parseOptional(); + } + + public int maxFilesPerMicroBatch() { + return confParser + .intConf() + .option(SparkReadOptions.STREAMING_MAX_FILES_PER_MICRO_BATCH) + .defaultValue(Integer.MAX_VALUE) + .parse(); + } + + public int maxRecordsPerMicroBatch() { + return confParser + .intConf() + .option(SparkReadOptions.STREAMING_MAX_ROWS_PER_MICRO_BATCH) + .defaultValue(Integer.MAX_VALUE) + .parse(); + } + + public boolean preserveDataGrouping() { + return confParser + .booleanConf() + .sessionConf(SparkSQLProperties.PRESERVE_DATA_GROUPING) + .defaultValue(SparkSQLProperties.PRESERVE_DATA_GROUPING_DEFAULT) + .parse(); + } + + public boolean aggregatePushDownEnabled() { + return confParser + .booleanConf() + .option(SparkReadOptions.AGGREGATE_PUSH_DOWN_ENABLED) + .sessionConf(SparkSQLProperties.AGGREGATE_PUSH_DOWN_ENABLED) + .defaultValue(SparkSQLProperties.AGGREGATE_PUSH_DOWN_ENABLED_DEFAULT) + .parse(); + } + + public boolean adaptiveSplitSizeEnabled() { + return confParser + .booleanConf() + .tableProperty(TableProperties.ADAPTIVE_SPLIT_SIZE_ENABLED) + .defaultValue(TableProperties.ADAPTIVE_SPLIT_SIZE_ENABLED_DEFAULT) + .parse(); + } + + public int parallelism() { + int defaultParallelism = spark.sparkContext().defaultParallelism(); + int numShufflePartitions = spark.sessionState().conf().numShufflePartitions(); + return Math.max(defaultParallelism, numShufflePartitions); + } + + public boolean distributedPlanningEnabled() { + return dataPlanningMode() != LOCAL || deletePlanningMode() != LOCAL; + } + + public PlanningMode dataPlanningMode() { + if (driverMaxResultSize() < DISTRIBUTED_PLANNING_MIN_RESULT_SIZE) { + return LOCAL; + } + + return confParser + .enumConf(PlanningMode::fromName) + .sessionConf(SparkSQLProperties.DATA_PLANNING_MODE) + .tableProperty(TableProperties.DATA_PLANNING_MODE) + .defaultValue(TableProperties.PLANNING_MODE_DEFAULT) + .parse(); + } + + public PlanningMode deletePlanningMode() { + if (driverMaxResultSize() < DISTRIBUTED_PLANNING_MIN_RESULT_SIZE) { + return LOCAL; + } + + String modeName = + confParser + .stringConf() + .sessionConf(SparkSQLProperties.DELETE_PLANNING_MODE) + .tableProperty(TableProperties.DELETE_PLANNING_MODE) + .defaultValue(TableProperties.PLANNING_MODE_DEFAULT) + .parse(); + return PlanningMode.fromName(modeName); + } + + private long driverMaxResultSize() { + SparkConf sparkConf = spark.sparkContext().conf(); + return sparkConf.getSizeAsBytes(DRIVER_MAX_RESULT_SIZE, DRIVER_MAX_RESULT_SIZE_DEFAULT); + } + + public boolean executorCacheLocalityEnabled() { + return executorCacheEnabled() && executorCacheLocalityEnabledInternal(); + } + + private boolean executorCacheEnabled() { + return confParser + .booleanConf() + .sessionConf(SparkSQLProperties.EXECUTOR_CACHE_ENABLED) + .defaultValue(SparkSQLProperties.EXECUTOR_CACHE_ENABLED_DEFAULT) + .parse(); + } + + private boolean executorCacheLocalityEnabledInternal() { + return confParser + .booleanConf() + .sessionConf(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED) + .defaultValue(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED_DEFAULT) + .parse(); + } + + public boolean reportColumnStats() { + return confParser + .booleanConf() + .sessionConf(SparkSQLProperties.REPORT_COLUMN_STATS) + .defaultValue(SparkSQLProperties.REPORT_COLUMN_STATS_DEFAULT) + .parse(); 
+ } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java new file mode 100644 index 000000000000..17f2bfee69b8 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +/** Spark DF read options */ +public class SparkReadOptions { + + private SparkReadOptions() {} + + // Snapshot ID of the table snapshot to read + public static final String SNAPSHOT_ID = "snapshot-id"; + + // Start snapshot ID used in incremental scans (exclusive) + public static final String START_SNAPSHOT_ID = "start-snapshot-id"; + + // End snapshot ID used in incremental scans (inclusive) + public static final String END_SNAPSHOT_ID = "end-snapshot-id"; + + // Start timestamp used in multi-snapshot scans (exclusive) + public static final String START_TIMESTAMP = "start-timestamp"; + + // End timestamp used in multi-snapshot scans (inclusive) + public static final String END_TIMESTAMP = "end-timestamp"; + + // A timestamp in milliseconds; the snapshot used will be the snapshot current at this time. 
+ public static final String AS_OF_TIMESTAMP = "as-of-timestamp"; + + // Branch to read from + public static final String BRANCH = "branch"; + + // Tag to read from + public static final String TAG = "tag"; + + // Overrides the table's read.split.target-size and read.split.metadata-target-size + public static final String SPLIT_SIZE = "split-size"; + + // Overrides the table's read.split.planning-lookback + public static final String LOOKBACK = "lookback"; + + // Overrides the table's read.split.open-file-cost + public static final String FILE_OPEN_COST = "file-open-cost"; + + // Overrides table's vectorization enabled properties + public static final String VECTORIZATION_ENABLED = "vectorization-enabled"; + + // Overrides the table's read.parquet.vectorization.batch-size + public static final String VECTORIZATION_BATCH_SIZE = "batch-size"; + + // Set ID that is used to fetch scan tasks + public static final String SCAN_TASK_SET_ID = "scan-task-set-id"; + + // skip snapshots of type delete while reading stream out of iceberg table + public static final String STREAMING_SKIP_DELETE_SNAPSHOTS = "streaming-skip-delete-snapshots"; + public static final boolean STREAMING_SKIP_DELETE_SNAPSHOTS_DEFAULT = false; + + // skip snapshots of type overwrite while reading stream out of iceberg table + public static final String STREAMING_SKIP_OVERWRITE_SNAPSHOTS = + "streaming-skip-overwrite-snapshots"; + public static final boolean STREAMING_SKIP_OVERWRITE_SNAPSHOTS_DEFAULT = false; + + // Controls whether to report locality information to Spark while allocating input partitions + public static final String LOCALITY = "locality"; + + // Timestamp in milliseconds; start a stream from the snapshot that occurs after this timestamp + public static final String STREAM_FROM_TIMESTAMP = "stream-from-timestamp"; + + // maximum file per micro_batch + public static final String STREAMING_MAX_FILES_PER_MICRO_BATCH = + "streaming-max-files-per-micro-batch"; + // maximum rows per micro_batch + public static final String STREAMING_MAX_ROWS_PER_MICRO_BATCH = + "streaming-max-rows-per-micro-batch"; + + // Table path + public static final String PATH = "path"; + + public static final String VERSION_AS_OF = "versionAsOf"; + + public static final String TIMESTAMP_AS_OF = "timestampAsOf"; + + public static final String AGGREGATE_PUSH_DOWN_ENABLED = "aggregate-push-down-enabled"; +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java new file mode 100644 index 000000000000..9130e63ba97e --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
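// A usage sketch for the read options above, assuming a SparkSession `spark`; the table name and
// snapshot id are illustrative. Per SparkReadConf, a value set here overrides the corresponding
// session configuration and table property.
Dataset<Row> df =
    spark
        .read()
        .format("iceberg")
        .option(SparkReadOptions.SNAPSHOT_ID, 1234567890123L)
        .option(SparkReadOptions.VECTORIZATION_ENABLED, "false")
        .load("db.events");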
+ */ +package org.apache.iceberg.spark; + +import java.time.Duration; + +public class SparkSQLProperties { + + private SparkSQLProperties() {} + + // Controls whether vectorized reads are enabled + public static final String VECTORIZATION_ENABLED = "spark.sql.iceberg.vectorization.enabled"; + + // Controls whether to perform the nullability check during writes + public static final String CHECK_NULLABILITY = "spark.sql.iceberg.check-nullability"; + public static final boolean CHECK_NULLABILITY_DEFAULT = true; + + // Controls whether to check the order of fields during writes + public static final String CHECK_ORDERING = "spark.sql.iceberg.check-ordering"; + public static final boolean CHECK_ORDERING_DEFAULT = true; + + // Controls whether to preserve the existing grouping of data while planning splits + public static final String PRESERVE_DATA_GROUPING = + "spark.sql.iceberg.planning.preserve-data-grouping"; + public static final boolean PRESERVE_DATA_GROUPING_DEFAULT = false; + + // Controls whether to push down aggregate (MAX/MIN/COUNT) to Iceberg + public static final String AGGREGATE_PUSH_DOWN_ENABLED = + "spark.sql.iceberg.aggregate-push-down.enabled"; + public static final boolean AGGREGATE_PUSH_DOWN_ENABLED_DEFAULT = true; + + // Controls write distribution mode + public static final String DISTRIBUTION_MODE = "spark.sql.iceberg.distribution-mode"; + + // Controls the WAP ID used for write-audit-publish workflow. + // When set, new snapshots will be staged with this ID in snapshot summary. + public static final String WAP_ID = "spark.wap.id"; + + // Controls the WAP branch used for write-audit-publish workflow. + // When set, new snapshots will be committed to this branch. + public static final String WAP_BRANCH = "spark.wap.branch"; + + // Controls write compress options + public static final String COMPRESSION_CODEC = "spark.sql.iceberg.compression-codec"; + public static final String COMPRESSION_LEVEL = "spark.sql.iceberg.compression-level"; + public static final String COMPRESSION_STRATEGY = "spark.sql.iceberg.compression-strategy"; + + // Overrides the data planning mode + public static final String DATA_PLANNING_MODE = "spark.sql.iceberg.data-planning-mode"; + + // Overrides the delete planning mode + public static final String DELETE_PLANNING_MODE = "spark.sql.iceberg.delete-planning-mode"; + + // Overrides the advisory partition size + public static final String ADVISORY_PARTITION_SIZE = "spark.sql.iceberg.advisory-partition-size"; + + // Controls whether to report locality information to Spark while allocating input partitions + public static final String LOCALITY = "spark.sql.iceberg.locality.enabled"; + + public static final String EXECUTOR_CACHE_ENABLED = "spark.sql.iceberg.executor-cache.enabled"; + public static final boolean EXECUTOR_CACHE_ENABLED_DEFAULT = true; + + public static final String EXECUTOR_CACHE_TIMEOUT = "spark.sql.iceberg.executor-cache.timeout"; + public static final Duration EXECUTOR_CACHE_TIMEOUT_DEFAULT = Duration.ofMinutes(10); + + public static final String EXECUTOR_CACHE_MAX_ENTRY_SIZE = + "spark.sql.iceberg.executor-cache.max-entry-size"; + public static final long EXECUTOR_CACHE_MAX_ENTRY_SIZE_DEFAULT = 64 * 1024 * 1024; // 64 MB + + public static final String EXECUTOR_CACHE_MAX_TOTAL_SIZE = + "spark.sql.iceberg.executor-cache.max-total-size"; + public static final long EXECUTOR_CACHE_MAX_TOTAL_SIZE_DEFAULT = 128 * 1024 * 1024; // 128 MB + + // Controls whether to merge schema during write operation + public static final String MERGE_SCHEMA = 
"spark.sql.iceberg.merge-schema"; + public static final boolean MERGE_SCHEMA_DEFAULT = false; + + public static final String EXECUTOR_CACHE_LOCALITY_ENABLED = + "spark.sql.iceberg.executor-cache.locality.enabled"; + public static final boolean EXECUTOR_CACHE_LOCALITY_ENABLED_DEFAULT = false; + + // Controls whether to report available column statistics to Spark for query optimization. + public static final String REPORT_COLUMN_STATS = "spark.sql.iceberg.report-column-stats"; + public static final boolean REPORT_COLUMN_STATS_DEFAULT = true; +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java new file mode 100644 index 000000000000..d0f77bcdd9cc --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java @@ -0,0 +1,368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.math.LongMath; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalog.Column; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructType; + +/** Helper methods for working with Spark/Hive metadata. */ +public class SparkSchemaUtil { + private SparkSchemaUtil() {} + + /** + * Returns a {@link Schema} for the given table with fresh field ids. + * + *
<p>
This creates a Schema for an existing table by looking up the table's schema with Spark and + * converting that schema. Spark/Hive partition columns are included in the schema. + * + * @param spark a Spark session + * @param name a table name and (optional) database + * @return a Schema for the table, if found + */ + public static Schema schemaForTable(SparkSession spark, String name) { + return convert(spark.table(name).schema()); + } + + /** + * Returns a {@link PartitionSpec} for the given table. + * + *

This creates a partition spec for an existing table by looking up the table's schema and + * creating a spec with identity partitions for each partition column. + * + * @param spark a Spark session + * @param name a table name and (optional) database + * @return a PartitionSpec for the table + * @throws AnalysisException if thrown by the Spark catalog + */ + public static PartitionSpec specForTable(SparkSession spark, String name) + throws AnalysisException { + List parts = Lists.newArrayList(Splitter.on('.').limit(2).split(name)); + String db = parts.size() == 1 ? "default" : parts.get(0); + String table = parts.get(parts.size() == 1 ? 0 : 1); + + PartitionSpec spec = + identitySpec( + schemaForTable(spark, name), spark.catalog().listColumns(db, table).collectAsList()); + return spec == null ? PartitionSpec.unpartitioned() : spec; + } + + /** + * Convert a {@link Schema} to a {@link DataType Spark type}. + * + * @param schema a Schema + * @return the equivalent Spark type + * @throws IllegalArgumentException if the type cannot be converted to Spark + */ + public static StructType convert(Schema schema) { + return (StructType) TypeUtil.visit(schema, new TypeToSparkType()); + } + + /** + * Convert a {@link Type} to a {@link DataType Spark type}. + * + * @param type a Type + * @return the equivalent Spark type + * @throws IllegalArgumentException if the type cannot be converted to Spark + */ + public static DataType convert(Type type) { + return TypeUtil.visit(type, new TypeToSparkType()); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} with new field ids. + * + *
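A minimal usage sketch (not part of the patch) for the helpers above: derive an Iceberg Schema and PartitionSpec from an existing Spark/Hive table and convert the Schema back to a Spark StructType. The table name and the session property values (keys from the SparkSQLProperties constants earlier in this hunk) are illustrative.

```java
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;

public class SchemaForTableExample {
  public static void main(String[] args) throws AnalysisException {
    SparkSession spark =
        SparkSession.builder()
            .master("local[*]")
            .appName("schema-for-table")
            // session properties defined in SparkSQLProperties above (illustrative values)
            .config("spark.sql.iceberg.merge-schema", "false")
            .config("spark.sql.iceberg.executor-cache.locality.enabled", "false")
            .config("spark.sql.iceberg.report-column-stats", "true")
            .getOrCreate();

    // convert the Spark table's schema, assigning fresh field ids
    Schema schema = SparkSchemaUtil.schemaForTable(spark, "db.sample");

    // identity partitions built from the table's Spark/Hive partition columns
    PartitionSpec spec = SparkSchemaUtil.specForTable(spark, "db.sample");

    // round-trip the Iceberg schema back to a Spark struct type
    StructType sparkType = SparkSchemaUtil.convert(schema);

    System.out.println(schema);
    System.out.println(spec);
    System.out.println(sparkType.treeString());
    spark.stop();
  }
}
```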

This conversion assigns fresh ids. + * + *

Some data types are represented as the same Spark type. These are converted to a default + * type. + * + *

To convert using a reference schema for field ids and ambiguous types, use {@link + * #convert(Schema, StructType)}. + * + * @param sparkType a Spark StructType + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted + */ + public static Schema convert(StructType sparkType) { + Type converted = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)); + return new Schema(converted.asNestedType().asStructType().fields()); + } + + /** + * Convert a Spark {@link DataType struct} to a {@link Type} with new field ids. + * + *

This conversion assigns fresh ids. + * + *

Some data types are represented as the same Spark type. These are converted to a default + * type. + * + *

To convert using a reference schema for field ids and ambiguous types, use {@link + * #convert(Schema, StructType)}. + * + * @param sparkType a Spark DataType + * @return the equivalent Type + * @throws IllegalArgumentException if the type cannot be converted + */ + public static Type convert(DataType sparkType) { + return SparkTypeVisitor.visit(sparkType, new SparkTypeToType()); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema. + * + *

This conversion does not assign new ids; it uses ids from the base schema. + * + *

Data types, field order, and nullability will match the spark type. This conversion may + * return a schema that is not compatible with base schema. + * + * @param baseSchema a Schema on which conversion is based + * @param sparkType a Spark StructType + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted or there are missing ids + */ + public static Schema convert(Schema baseSchema, StructType sparkType) { + return convert(baseSchema, sparkType, true); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema. + * + *

This conversion does not assign new ids; it uses ids from the base schema. + * + *

Data types, field order, and nullability will match the spark type. This conversion may + * return a schema that is not compatible with base schema. + * + * @param baseSchema a Schema on which conversion is based + * @param sparkType a Spark StructType + * @param caseSensitive when false, the case of schema fields is ignored + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted or there are missing ids + */ + public static Schema convert(Schema baseSchema, StructType sparkType, boolean caseSensitive) { + // convert to a type with fresh ids + Types.StructType struct = + SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)).asStructType(); + // reassign ids to match the base schema + Schema schema = TypeUtil.reassignIds(new Schema(struct.fields()), baseSchema, caseSensitive); + // fix types that can't be represented in Spark (UUID and Fixed) + return SparkFixupTypes.fixup(schema, baseSchema); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema. + * + *

This conversion will assign new ids for fields that are not found in the base schema. + * + *

Data types, field order, and nullability will match the spark type. This conversion may + * return a schema that is not compatible with base schema. + * + * @param baseSchema a Schema on which conversion is based + * @param sparkType a Spark StructType + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted or there are missing ids + */ + public static Schema convertWithFreshIds(Schema baseSchema, StructType sparkType) { + return convertWithFreshIds(baseSchema, sparkType, true); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema. + * + *

This conversion will assign new ids for fields that are not found in the base schema. + * + *

Data types, field order, and nullability will match the spark type. This conversion may + * return a schema that is not compatible with base schema. + * + * @param baseSchema a Schema on which conversion is based + * @param sparkType a Spark StructType + * @param caseSensitive when false, case of field names in schema is ignored + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted or there are missing ids + */ + public static Schema convertWithFreshIds( + Schema baseSchema, StructType sparkType, boolean caseSensitive) { + // convert to a type with fresh ids + Types.StructType struct = + SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)).asStructType(); + // reassign ids to match the base schema + Schema schema = + TypeUtil.reassignOrRefreshIds(new Schema(struct.fields()), baseSchema, caseSensitive); + // fix types that can't be represented in Spark (UUID and Fixed) + return SparkFixupTypes.fixup(schema, baseSchema); + } + + /** + * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection. + * + *
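A short sketch (not part of the patch) contrasting the two id-assignment modes: convert(...) reuses ids from the base schema, while convertWithFreshIds(...) also assigns new ids to fields the base schema does not contain. The field names and the extra column are illustrative.

```java
import org.apache.iceberg.Schema;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class SchemaIdReuseExample {
  public static void main(String[] args) {
    Schema base =
        new Schema(
            Types.NestedField.required(1, "id", Types.LongType.get()),
            Types.NestedField.optional(2, "data", Types.StringType.get()));

    // a Spark struct matching the base schema: ids 1 and 2 are reused
    StructType matching =
        new StructType()
            .add("id", DataTypes.LongType, false)
            .add("data", DataTypes.StringType, true);
    Schema sameIds = SparkSchemaUtil.convert(base, matching);

    // a widened struct with a column the base schema lacks: known fields keep
    // their ids, "category" receives a fresh id
    StructType widened = matching.add("category", DataTypes.StringType, true);
    Schema withNewId = SparkSchemaUtil.convertWithFreshIds(base, widened);

    System.out.println(sameIds);
    System.out.println(withNewId);
  }
}
```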

This requires that the Spark type is a projection of the Schema. Nullability and types must + * match. + * + * @param schema a Schema + * @param requestedType a projection of the Spark representation of the Schema + * @return a Schema corresponding to the Spark projection + * @throws IllegalArgumentException if the Spark type does not match the Schema + */ + public static Schema prune(Schema schema, StructType requestedType) { + return new Schema( + TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, ImmutableSet.of())) + .asNestedType() + .asStructType() + .fields()); + } + + /** + * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection. + * + *

This requires that the Spark type is a projection of the Schema. Nullability and types must + * match. + * + *

The filters list of {@link Expression} is used to ensure that columns referenced by filters + * are projected. + * + * @param schema a Schema + * @param requestedType a projection of the Spark representation of the Schema + * @param filters a list of filters + * @return a Schema corresponding to the Spark projection + * @throws IllegalArgumentException if the Spark type does not match the Schema + */ + public static Schema prune(Schema schema, StructType requestedType, List filters) { + Set filterRefs = Binder.boundReferences(schema.asStruct(), filters, true); + return new Schema( + TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, filterRefs)) + .asNestedType() + .asStructType() + .fields()); + } + + /** + * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection. + * + *
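A minimal pruning sketch (not part of the patch): the requested struct must be a projection of the schema, with matching types and nullability. The schema and requested columns below are illustrative.

```java
import org.apache.iceberg.Schema;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class PruneExample {
  public static void main(String[] args) {
    Schema schema =
        new Schema(
            Types.NestedField.required(1, "id", Types.LongType.get()),
            Types.NestedField.optional(2, "data", Types.StringType.get()),
            Types.NestedField.optional(3, "category", Types.StringType.get()));

    // Spark asked only for id and data
    StructType requested =
        new StructType()
            .add("id", DataTypes.LongType, false)
            .add("data", DataTypes.StringType, true);

    // the pruned schema keeps the original field ids 1 and 2 and drops "category"
    Schema pruned = SparkSchemaUtil.prune(schema, requested);
    System.out.println(pruned);
  }
}
```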

This requires that the Spark type is a projection of the Schema. Nullability and types must + * match. + * + *

The filters list of {@link Expression} is used to ensure that columns referenced by filters + * are projected. + * + * @param schema a Schema + * @param requestedType a projection of the Spark representation of the Schema + * @param filter a filters + * @return a Schema corresponding to the Spark projection + * @throws IllegalArgumentException if the Spark type does not match the Schema + */ + public static Schema prune( + Schema schema, StructType requestedType, Expression filter, boolean caseSensitive) { + Set filterRefs = + Binder.boundReferences(schema.asStruct(), Collections.singletonList(filter), caseSensitive); + + return new Schema( + TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, filterRefs)) + .asNestedType() + .asStructType() + .fields()); + } + + private static PartitionSpec identitySpec(Schema schema, Collection columns) { + List names = Lists.newArrayList(); + for (Column column : columns) { + if (column.isPartition()) { + names.add(column.name()); + } + } + + return identitySpec(schema, names); + } + + private static PartitionSpec identitySpec(Schema schema, List partitionNames) { + if (partitionNames == null || partitionNames.isEmpty()) { + return null; + } + + PartitionSpec.Builder builder = PartitionSpec.builderFor(schema); + for (String partitionName : partitionNames) { + builder.identity(partitionName); + } + + return builder.build(); + } + + /** + * Estimate approximate table size based on Spark schema and total records. + * + * @param tableSchema Spark schema + * @param totalRecords total records in the table + * @return approximate size based on table schema + */ + public static long estimateSize(StructType tableSchema, long totalRecords) { + if (totalRecords == Long.MAX_VALUE) { + return totalRecords; + } + + long result; + try { + result = LongMath.checkedMultiply(tableSchema.defaultSize(), totalRecords); + } catch (ArithmeticException e) { + result = Long.MAX_VALUE; + } + return result; + } + + public static void validateMetadataColumnReferences(Schema tableSchema, Schema readSchema) { + List conflictingColumnNames = + readSchema.columns().stream() + .map(Types.NestedField::name) + .filter( + name -> + MetadataColumns.isMetadataColumn(name) && tableSchema.findField(name) != null) + .collect(Collectors.toList()); + + ValidationException.check( + conflictingColumnNames.isEmpty(), + "Table column names conflict with names reserved for Iceberg metadata columns: %s.\n" + + "Please, use ALTER TABLE statements to rename the conflicting table columns.", + conflictingColumnNames); + } + + public static Map indexQuotedNameById(Schema schema) { + Function quotingFunc = name -> String.format("`%s`", name.replace("`", "``")); + return TypeUtil.indexQuotedNameById(schema.asStruct(), quotingFunc); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSessionCatalog.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSessionCatalog.java new file mode 100644 index 000000000000..fa3f1fbe4b2a --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSessionCatalog.java @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.source.HasIcebergCatalog; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException; +import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.connector.catalog.CatalogExtension; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.FunctionCatalog; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.NamespaceChange; +import org.apache.spark.sql.connector.catalog.StagedTable; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.SupportsNamespaces; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * A Spark catalog that can also load non-Iceberg tables. + * + * @param CatalogPlugin class to avoid casting to TableCatalog, FunctionCatalog and + * SupportsNamespaces. + */ +public class SparkSessionCatalog + extends BaseCatalog implements CatalogExtension { + private static final String[] DEFAULT_NAMESPACE = new String[] {"default"}; + + private String catalogName = null; + private TableCatalog icebergCatalog = null; + private StagingTableCatalog asStagingCatalog = null; + private T sessionCatalog = null; + private boolean createParquetAsIceberg = false; + private boolean createAvroAsIceberg = false; + private boolean createOrcAsIceberg = false; + + /** + * Build a {@link SparkCatalog} to be used for Iceberg operations. + * + *
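A hypothetical configuration sketch (not part of the patch) for wiring this class in as Spark's built-in session catalog, so Iceberg and non-Iceberg tables resolve through one catalog. It assumes a Hive metastore is available; the parquet-enabled flag corresponds to the catalog option read in initialize(...) further down.

```java
import org.apache.spark.sql.SparkSession;

public class SessionCatalogConfigExample {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder()
            .master("local[*]")
            .config(
                "spark.sql.catalog.spark_catalog",
                "org.apache.iceberg.spark.SparkSessionCatalog")
            .config("spark.sql.catalog.spark_catalog.type", "hive")
            // also route CREATE TABLE ... USING parquet through the Iceberg catalog
            .config("spark.sql.catalog.spark_catalog.parquet-enabled", "true")
            .enableHiveSupport()
            .getOrCreate();

    // Iceberg identifiers resolve through the wrapped SparkCatalog; anything else
    // falls back to the delegated built-in session catalog.
    spark.sql("SHOW TABLES IN default").show();
    spark.stop();
  }
}
```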

The default implementation creates a new SparkCatalog with the session catalog's name and + * options. + * + * @param name catalog name + * @param options catalog options + * @return a SparkCatalog to be used for Iceberg tables + */ + protected TableCatalog buildSparkCatalog(String name, CaseInsensitiveStringMap options) { + SparkCatalog newCatalog = new SparkCatalog(); + newCatalog.initialize(name, options); + return newCatalog; + } + + @Override + public String[] defaultNamespace() { + return DEFAULT_NAMESPACE; + } + + @Override + public String[][] listNamespaces() throws NoSuchNamespaceException { + return getSessionCatalog().listNamespaces(); + } + + @Override + public String[][] listNamespaces(String[] namespace) throws NoSuchNamespaceException { + return getSessionCatalog().listNamespaces(namespace); + } + + @Override + public boolean namespaceExists(String[] namespace) { + return getSessionCatalog().namespaceExists(namespace); + } + + @Override + public Map loadNamespaceMetadata(String[] namespace) + throws NoSuchNamespaceException { + return getSessionCatalog().loadNamespaceMetadata(namespace); + } + + @Override + public void createNamespace(String[] namespace, Map metadata) + throws NamespaceAlreadyExistsException { + getSessionCatalog().createNamespace(namespace, metadata); + } + + @Override + public void alterNamespace(String[] namespace, NamespaceChange... changes) + throws NoSuchNamespaceException { + getSessionCatalog().alterNamespace(namespace, changes); + } + + @Override + public boolean dropNamespace(String[] namespace, boolean cascade) + throws NoSuchNamespaceException, NonEmptyNamespaceException { + return getSessionCatalog().dropNamespace(namespace, cascade); + } + + @Override + public Identifier[] listTables(String[] namespace) throws NoSuchNamespaceException { + // delegate to the session catalog because all tables share the same namespace + return getSessionCatalog().listTables(namespace); + } + + @Override + public Table loadTable(Identifier ident) throws NoSuchTableException { + try { + return icebergCatalog.loadTable(ident); + } catch (NoSuchTableException e) { + return getSessionCatalog().loadTable(ident); + } + } + + @Override + public Table loadTable(Identifier ident, String version) throws NoSuchTableException { + try { + return icebergCatalog.loadTable(ident, version); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + return getSessionCatalog().loadTable(ident, version); + } + } + + @Override + public Table loadTable(Identifier ident, long timestamp) throws NoSuchTableException { + try { + return icebergCatalog.loadTable(ident, timestamp); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + return getSessionCatalog().loadTable(ident, timestamp); + } + } + + @Override + public void invalidateTable(Identifier ident) { + // We do not need to check whether the table exists and whether + // it is an Iceberg table to reduce remote service requests. 
+ icebergCatalog.invalidateTable(ident); + getSessionCatalog().invalidateTable(ident); + } + + @Override + public Table createTable( + Identifier ident, StructType schema, Transform[] partitions, Map properties) + throws TableAlreadyExistsException, NoSuchNamespaceException { + String provider = properties.get("provider"); + if (useIceberg(provider)) { + return icebergCatalog.createTable(ident, schema, partitions, properties); + } else { + // delegate to the session catalog + return getSessionCatalog().createTable(ident, schema, partitions, properties); + } + } + + @Override + public StagedTable stageCreate( + Identifier ident, StructType schema, Transform[] partitions, Map properties) + throws TableAlreadyExistsException, NoSuchNamespaceException { + String provider = properties.get("provider"); + TableCatalog catalog; + if (useIceberg(provider)) { + if (asStagingCatalog != null) { + return asStagingCatalog.stageCreate(ident, schema, partitions, properties); + } + catalog = icebergCatalog; + } else { + catalog = getSessionCatalog(); + } + + // create the table with the session catalog, then wrap it in a staged table that will delete to + // roll back + Table table = catalog.createTable(ident, schema, partitions, properties); + return new RollbackStagedTable(catalog, ident, table); + } + + @Override + public StagedTable stageReplace( + Identifier ident, StructType schema, Transform[] partitions, Map properties) + throws NoSuchNamespaceException, NoSuchTableException { + String provider = properties.get("provider"); + TableCatalog catalog; + if (useIceberg(provider)) { + if (asStagingCatalog != null) { + return asStagingCatalog.stageReplace(ident, schema, partitions, properties); + } + catalog = icebergCatalog; + } else { + catalog = getSessionCatalog(); + } + + // attempt to drop the table and fail if it doesn't exist + if (!catalog.dropTable(ident)) { + throw new NoSuchTableException(ident); + } + + try { + // create the table with the session catalog, then wrap it in a staged table that will delete + // to roll back + Table table = catalog.createTable(ident, schema, partitions, properties); + return new RollbackStagedTable(catalog, ident, table); + + } catch (TableAlreadyExistsException e) { + // the table was deleted, but now already exists again. retry the replace. + return stageReplace(ident, schema, partitions, properties); + } + } + + @Override + public StagedTable stageCreateOrReplace( + Identifier ident, StructType schema, Transform[] partitions, Map properties) + throws NoSuchNamespaceException { + String provider = properties.get("provider"); + TableCatalog catalog; + if (useIceberg(provider)) { + if (asStagingCatalog != null) { + return asStagingCatalog.stageCreateOrReplace(ident, schema, partitions, properties); + } + catalog = icebergCatalog; + } else { + catalog = getSessionCatalog(); + } + + // drop the table if it exists + catalog.dropTable(ident); + + try { + // create the table with the session catalog, then wrap it in a staged table that will delete + // to roll back + Table sessionCatalogTable = catalog.createTable(ident, schema, partitions, properties); + return new RollbackStagedTable(catalog, ident, sessionCatalogTable); + + } catch (TableAlreadyExistsException e) { + // the table was deleted, but now already exists again. retry the replace. + return stageCreateOrReplace(ident, schema, partitions, properties); + } + } + + @Override + public Table alterTable(Identifier ident, TableChange... 
changes) throws NoSuchTableException { + if (icebergCatalog.tableExists(ident)) { + return icebergCatalog.alterTable(ident, changes); + } else { + return getSessionCatalog().alterTable(ident, changes); + } + } + + @Override + public boolean dropTable(Identifier ident) { + // no need to check table existence to determine which catalog to use. if a table doesn't exist + // then both are + // required to return false. + return icebergCatalog.dropTable(ident) || getSessionCatalog().dropTable(ident); + } + + @Override + public boolean purgeTable(Identifier ident) { + // no need to check table existence to determine which catalog to use. if a table doesn't exist + // then both are + // required to return false. + return icebergCatalog.purgeTable(ident) || getSessionCatalog().purgeTable(ident); + } + + @Override + public void renameTable(Identifier from, Identifier to) + throws NoSuchTableException, TableAlreadyExistsException { + // rename is not supported by HadoopCatalog. to avoid UnsupportedOperationException for session + // catalog tables, + // check table existence first to ensure that the table belongs to the Iceberg catalog. + if (icebergCatalog.tableExists(from)) { + icebergCatalog.renameTable(from, to); + } else { + getSessionCatalog().renameTable(from, to); + } + } + + @Override + public final void initialize(String name, CaseInsensitiveStringMap options) { + super.initialize(name, options); + + if (options.containsKey(CatalogUtil.ICEBERG_CATALOG_TYPE) + && options + .get(CatalogUtil.ICEBERG_CATALOG_TYPE) + .equalsIgnoreCase(CatalogUtil.ICEBERG_CATALOG_TYPE_HIVE)) { + validateHmsUri(options.get(CatalogProperties.URI)); + } + + this.catalogName = name; + this.icebergCatalog = buildSparkCatalog(name, options); + if (icebergCatalog instanceof StagingTableCatalog) { + this.asStagingCatalog = (StagingTableCatalog) icebergCatalog; + } + + this.createParquetAsIceberg = options.getBoolean("parquet-enabled", createParquetAsIceberg); + this.createAvroAsIceberg = options.getBoolean("avro-enabled", createAvroAsIceberg); + this.createOrcAsIceberg = options.getBoolean("orc-enabled", createOrcAsIceberg); + } + + private void validateHmsUri(String catalogHmsUri) { + if (catalogHmsUri == null) { + return; + } + + Configuration conf = SparkSession.active().sessionState().newHadoopConf(); + String envHmsUri = conf.get(HiveConf.ConfVars.METASTOREURIS.varname, null); + if (envHmsUri == null) { + return; + } + + Preconditions.checkArgument( + catalogHmsUri.equals(envHmsUri), + "Inconsistent Hive metastore URIs: %s (Spark session) != %s (spark_catalog)", + envHmsUri, + catalogHmsUri); + } + + @Override + @SuppressWarnings("unchecked") + public void setDelegateCatalog(CatalogPlugin sparkSessionCatalog) { + if (sparkSessionCatalog instanceof TableCatalog + && sparkSessionCatalog instanceof FunctionCatalog + && sparkSessionCatalog instanceof SupportsNamespaces) { + this.sessionCatalog = (T) sparkSessionCatalog; + } else { + throw new IllegalArgumentException("Invalid session catalog: " + sparkSessionCatalog); + } + } + + @Override + public String name() { + return catalogName; + } + + private boolean useIceberg(String provider) { + if (provider == null || "iceberg".equalsIgnoreCase(provider)) { + return true; + } else if (createParquetAsIceberg && "parquet".equalsIgnoreCase(provider)) { + return true; + } else if (createAvroAsIceberg && "avro".equalsIgnoreCase(provider)) { + return true; + } else if (createOrcAsIceberg && "orc".equalsIgnoreCase(provider)) { + return true; + } + + return false; + } + + private 
T getSessionCatalog() { + Preconditions.checkNotNull( + sessionCatalog, + "Delegated SessionCatalog is missing. " + + "Please make sure your are replacing Spark's default catalog, named 'spark_catalog'."); + return sessionCatalog; + } + + @Override + public Catalog icebergCatalog() { + Preconditions.checkArgument( + icebergCatalog instanceof HasIcebergCatalog, + "Cannot return underlying Iceberg Catalog, wrapped catalog does not contain an Iceberg Catalog"); + return ((HasIcebergCatalog) icebergCatalog).icebergCatalog(); + } + + @Override + public UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException { + try { + return super.loadFunction(ident); + } catch (NoSuchFunctionException e) { + return getSessionCatalog().loadFunction(ident); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkStructLike.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkStructLike.java new file mode 100644 index 000000000000..77cfa0f34c63 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkStructLike.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.StructLike; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; + +public class SparkStructLike implements StructLike { + + private final Types.StructType type; + private Row wrapped; + + public SparkStructLike(Types.StructType type) { + this.type = type; + } + + public SparkStructLike wrap(Row row) { + this.wrapped = row; + return this; + } + + @Override + public int size() { + return type.fields().size(); + } + + @Override + public T get(int pos, Class javaClass) { + Types.NestedField field = type.fields().get(pos); + return javaClass.cast(SparkValueConverter.convert(field.type(), wrapped.get(pos))); + } + + @Override + public void set(int pos, T value) { + throw new UnsupportedOperationException("Not implemented: set"); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkTableCache.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkTableCache.java new file mode 100644 index 000000000000..6218423db491 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkTableCache.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; + +public class SparkTableCache { + + private static final SparkTableCache INSTANCE = new SparkTableCache(); + + private final Map cache = Maps.newConcurrentMap(); + + public static SparkTableCache get() { + return INSTANCE; + } + + public int size() { + return cache.size(); + } + + public void add(String key, Table table) { + cache.put(key, table); + } + + public boolean contains(String key) { + return cache.containsKey(key); + } + + public Table get(String key) { + return cache.get(key); + } + + public Table remove(String key) { + return cache.remove(key); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java new file mode 100644 index 000000000000..c44969c49e39 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java @@ -0,0 +1,974 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
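A small sketch (not part of the patch) for the two helpers above: SparkStructLike adapts a Spark Row to Iceberg's StructLike view, and SparkTableCache is a process-wide registry keyed by an arbitrary string. The field names, row values, and cache key are illustrative.

```java
import org.apache.iceberg.spark.SparkStructLike;
import org.apache.iceberg.spark.SparkTableCache;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;

public class StructLikeAndCacheExample {
  public static void main(String[] args) {
    Types.StructType type =
        Types.StructType.of(
            Types.NestedField.required(1, "id", Types.LongType.get()),
            Types.NestedField.optional(2, "data", Types.StringType.get()));

    Row row = RowFactory.create(1L, "a");
    SparkStructLike struct = new SparkStructLike(type).wrap(row);
    Long id = struct.get(0, Long.class);      // converted via SparkValueConverter
    String data = struct.get(1, String.class);
    System.out.println(id + " -> " + data);

    // registering a loaded Iceberg table (loading elided; the concurrent map
    // backing the cache rejects nulls, so a real Table handle is needed here)
    // SparkTableCache.get().add("prod.db.sample", table);
    // boolean cached = SparkTableCache.get().contains("prod.db.sample");
    System.out.println("cached tables: " + SparkTableCache.get().size());
  }
}
```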
+ */ +package org.apache.iceberg.spark; + +import static org.apache.spark.sql.functions.col; + +import java.io.IOException; +import java.io.Serializable; +import java.net.URI; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestWriter; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.MetadataTableUtils; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.TableMigrationUtil; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopFileIO; +import org.apache.iceberg.hadoop.SerializableConfiguration; +import org.apache.iceberg.hadoop.Util; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.TaskContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; +import org.apache.spark.sql.catalyst.catalog.CatalogTable; +import org.apache.spark.sql.catalyst.catalog.CatalogTablePartition; +import org.apache.spark.sql.catalyst.catalog.SessionCatalog; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import 
org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Function2; +import scala.Option; +import scala.Some; +import scala.Tuple2; +import scala.collection.JavaConverters; +import scala.collection.immutable.Map$; +import scala.collection.immutable.Seq; +import scala.collection.mutable.Builder; +import scala.runtime.AbstractPartialFunction; + +/** + * Java version of the original SparkTableUtil.scala + * https://github.com/apache/iceberg/blob/apache-iceberg-0.8.0-incubating/spark/src/main/scala/org/apache/iceberg/spark/SparkTableUtil.scala + */ +public class SparkTableUtil { + + private static final String DUPLICATE_FILE_MESSAGE = + "Cannot complete import because data files " + + "to be imported already exist within the target table: %s. " + + "This is disabled by default as Iceberg is not designed for multiple references to the same file" + + " within the same table. If you are sure, you may set 'check_duplicate_files' to false to force the import."; + + private SparkTableUtil() {} + + /** + * Returns a DataFrame with a row for each partition in the table. + * + *

The DataFrame has 3 columns, partition key (a=1/b=2), partition location, and format (avro + * or parquet). + * + * @param spark a Spark session + * @param table a table name and (optional) database + * @return a DataFrame of the table's partitions + */ + public static Dataset partitionDF(SparkSession spark, String table) { + List partitions = getPartitions(spark, table); + return spark + .createDataFrame(partitions, SparkPartition.class) + .toDF("partition", "uri", "format"); + } + + /** + * Returns a DataFrame with a row for each partition that matches the specified 'expression'. + * + * @param spark a Spark session. + * @param table name of the table. + * @param expression The expression whose matching partitions are returned. + * @return a DataFrame of the table partitions. + */ + public static Dataset partitionDFByFilter( + SparkSession spark, String table, String expression) { + List partitions = getPartitionsByFilter(spark, table, expression); + return spark + .createDataFrame(partitions, SparkPartition.class) + .toDF("partition", "uri", "format"); + } + + /** + * Returns all partitions in the table. + * + * @param spark a Spark session + * @param table a table name and (optional) database + * @return all table's partitions + */ + public static List getPartitions(SparkSession spark, String table) { + try { + TableIdentifier tableIdent = spark.sessionState().sqlParser().parseTableIdentifier(table); + return getPartitions(spark, tableIdent, null); + } catch (ParseException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unable to parse table identifier: %s", table); + } + } + + /** + * Returns all partitions in the table. + * + * @param spark a Spark session + * @param tableIdent a table identifier + * @param partitionFilter partition filter, or null if no filter + * @return all table's partitions + */ + public static List getPartitions( + SparkSession spark, TableIdentifier tableIdent, Map partitionFilter) { + try { + SessionCatalog catalog = spark.sessionState().catalog(); + CatalogTable catalogTable = catalog.getTableMetadata(tableIdent); + + Option> scalaPartitionFilter; + if (partitionFilter != null && !partitionFilter.isEmpty()) { + Builder, scala.collection.immutable.Map> builder = + Map$.MODULE$.newBuilder(); + partitionFilter.forEach((key, value) -> builder.$plus$eq(Tuple2.apply(key, value))); + scalaPartitionFilter = Option.apply(builder.result()); + } else { + scalaPartitionFilter = Option.empty(); + } + Seq partitions = + catalog.listPartitions(tableIdent, scalaPartitionFilter).toIndexedSeq(); + return JavaConverters.seqAsJavaListConverter(partitions).asJava().stream() + .map(catalogPartition -> toSparkPartition(catalogPartition, catalogTable)) + .collect(Collectors.toList()); + } catch (NoSuchDatabaseException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. Database not found in catalog.", tableIdent); + } catch (NoSuchTableException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. Table not found in catalog.", tableIdent); + } + } + + /** + * Returns partitions that match the specified 'predicate'. 
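A minimal sketch (not part of the patch) of inspecting the partitions of an existing, non-Iceberg Spark table before an import; the table name and predicate are illustrative.

```java
import org.apache.iceberg.spark.SparkTableUtil;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class PartitionInspectionExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();

    // one row per partition: partition key (a=1/b=2), location URI, file format
    Dataset<Row> partitions = SparkTableUtil.partitionDF(spark, "db.logs");
    partitions.show(false);

    // same, restricted to partitions matching a predicate on partition columns
    Dataset<Row> recent =
        SparkTableUtil.partitionDFByFilter(spark, "db.logs", "event_date >= '2024-01-01'");
    recent.show(false);

    spark.stop();
  }
}
```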
+ * + * @param spark a Spark session + * @param table a table name and (optional) database + * @param predicate a predicate on partition columns + * @return matching table's partitions + */ + public static List getPartitionsByFilter( + SparkSession spark, String table, String predicate) { + TableIdentifier tableIdent; + try { + tableIdent = spark.sessionState().sqlParser().parseTableIdentifier(table); + } catch (ParseException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unable to parse the table identifier: %s", table); + } + + Expression unresolvedPredicateExpr; + try { + unresolvedPredicateExpr = spark.sessionState().sqlParser().parseExpression(predicate); + } catch (ParseException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unable to parse the predicate expression: %s", predicate); + } + + Expression resolvedPredicateExpr = resolveAttrs(spark, table, unresolvedPredicateExpr); + return getPartitionsByFilter(spark, tableIdent, resolvedPredicateExpr); + } + + /** + * Returns partitions that match the specified 'predicate'. + * + * @param spark a Spark session + * @param tableIdent a table identifier + * @param predicateExpr a predicate expression on partition columns + * @return matching table's partitions + */ + public static List getPartitionsByFilter( + SparkSession spark, TableIdentifier tableIdent, Expression predicateExpr) { + try { + SessionCatalog catalog = spark.sessionState().catalog(); + CatalogTable catalogTable = catalog.getTableMetadata(tableIdent); + + Expression resolvedPredicateExpr; + if (!predicateExpr.resolved()) { + resolvedPredicateExpr = resolveAttrs(spark, tableIdent.quotedString(), predicateExpr); + } else { + resolvedPredicateExpr = predicateExpr; + } + Seq predicates = + JavaConverters.collectionAsScalaIterableConverter(ImmutableList.of(resolvedPredicateExpr)) + .asScala() + .toIndexedSeq(); + + Seq partitions = + catalog.listPartitionsByFilter(tableIdent, predicates).toIndexedSeq(); + + return JavaConverters.seqAsJavaListConverter(partitions).asJava().stream() + .map(catalogPartition -> toSparkPartition(catalogPartition, catalogTable)) + .collect(Collectors.toList()); + } catch (NoSuchDatabaseException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. Database not found in catalog.", tableIdent); + } catch (NoSuchTableException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. 
Table not found in catalog.", tableIdent); + } + } + + private static List listPartition( + SparkPartition partition, + PartitionSpec spec, + SerializableConfiguration conf, + MetricsConfig metricsConfig, + NameMapping mapping, + int parallelism) { + return TableMigrationUtil.listPartition( + partition.values, + partition.uri, + partition.format, + spec, + conf.get(), + metricsConfig, + mapping, + parallelism); + } + + private static List listPartition( + SparkPartition partition, + PartitionSpec spec, + SerializableConfiguration conf, + MetricsConfig metricsConfig, + NameMapping mapping, + ExecutorService service) { + return TableMigrationUtil.listPartition( + partition.values, + partition.uri, + partition.format, + spec, + conf.get(), + metricsConfig, + mapping, + service); + } + + private static SparkPartition toSparkPartition( + CatalogTablePartition partition, CatalogTable table) { + Option locationUri = partition.storage().locationUri(); + Option serde = partition.storage().serde(); + + Preconditions.checkArgument(locationUri.nonEmpty(), "Partition URI should be defined"); + Preconditions.checkArgument( + serde.nonEmpty() || table.provider().nonEmpty(), "Partition format should be defined"); + + String uri = Util.uriToString(locationUri.get()); + String format = serde.nonEmpty() ? serde.get() : table.provider().get(); + + Map partitionSpec = + JavaConverters.mapAsJavaMapConverter(partition.spec()).asJava(); + return new SparkPartition(partitionSpec, uri, format); + } + + private static Expression resolveAttrs(SparkSession spark, String table, Expression expr) { + Function2 resolver = spark.sessionState().analyzer().resolver(); + LogicalPlan plan = spark.table(table).queryExecution().analyzed(); + return expr.transform( + new AbstractPartialFunction() { + @Override + public Expression apply(Expression attr) { + UnresolvedAttribute unresolvedAttribute = (UnresolvedAttribute) attr; + Option namedExpressionOption = + plan.resolve(unresolvedAttribute.nameParts(), resolver); + if (namedExpressionOption.isDefined()) { + return (Expression) namedExpressionOption.get(); + } else { + throw new IllegalArgumentException( + String.format("Could not resolve %s using columns: %s", attr, plan.output())); + } + } + + @Override + public boolean isDefinedAt(Expression attr) { + return attr instanceof UnresolvedAttribute; + } + }); + } + + private static Iterator buildManifest( + SerializableConfiguration conf, + PartitionSpec spec, + String basePath, + Iterator> fileTuples) { + if (fileTuples.hasNext()) { + FileIO io = new HadoopFileIO(conf.get()); + TaskContext ctx = TaskContext.get(); + String suffix = + String.format( + "stage-%d-task-%d-manifest-%s", + ctx.stageId(), ctx.taskAttemptId(), UUID.randomUUID()); + Path location = new Path(basePath, suffix); + String outputPath = FileFormat.AVRO.addExtension(location.toString()); + OutputFile outputFile = io.newOutputFile(outputPath); + ManifestWriter writer = ManifestFiles.write(spec, outputFile); + + try (ManifestWriter writerRef = writer) { + fileTuples.forEachRemaining(fileTuple -> writerRef.add(fileTuple._2)); + } catch (IOException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unable to close the manifest writer: %s", outputPath); + } + + ManifestFile manifestFile = writer.toManifestFile(); + return ImmutableList.of(manifestFile).iterator(); + } else { + return Collections.emptyIterator(); + } + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *

The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param partitionFilter only import partitions whose values match those in the map, can be + * partially defined + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + Map partitionFilter, + boolean checkDuplicateFiles) { + importSparkTable( + spark, sourceTableIdent, targetTable, stagingDir, partitionFilter, checkDuplicateFiles, 1); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *

The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param parallelism number of threads to use for file reading + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + int parallelism) { + importSparkTable( + spark, + sourceTableIdent, + targetTable, + stagingDir, + TableMigrationUtil.migrationService(parallelism)); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *

The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param service executor service to use for file reading. If null, file reading will be + * performed on the current thread. * If non-null, the provided ExecutorService will be + * shutdown within this method after file reading is complete. + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + ExecutorService service) { + importSparkTable( + spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false, service); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *

The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param partitionFilter only import partitions whose values match those in the map, can be + * partially defined + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param parallelism number of threads to use for file reading + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + Map partitionFilter, + boolean checkDuplicateFiles, + int parallelism) { + importSparkTable( + spark, + sourceTableIdent, + targetTable, + stagingDir, + partitionFilter, + checkDuplicateFiles, + TableMigrationUtil.migrationService(parallelism)); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *

The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param partitionFilter only import partitions whose values match those in the map, can be + * partially defined + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param service executor service to use for file reading. If null, file reading will be + * performed on the current thread. If non-null, the provided ExecutorService will be shutdown + * within this method after file reading is complete. + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + Map partitionFilter, + boolean checkDuplicateFiles, + ExecutorService service) { + SessionCatalog catalog = spark.sessionState().catalog(); + + String db = + sourceTableIdent.database().nonEmpty() + ? sourceTableIdent.database().get() + : catalog.getCurrentDatabase(); + TableIdentifier sourceTableIdentWithDB = + new TableIdentifier(sourceTableIdent.table(), Some.apply(db)); + + if (!catalog.tableExists(sourceTableIdentWithDB)) { + throw new org.apache.iceberg.exceptions.NoSuchTableException( + "Table %s does not exist", sourceTableIdentWithDB); + } + + try { + PartitionSpec spec = + SparkSchemaUtil.specForTable(spark, sourceTableIdentWithDB.unquotedString()); + + if (Objects.equal(spec, PartitionSpec.unpartitioned())) { + importUnpartitionedSparkTable( + spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles, service); + } else { + List sourceTablePartitions = + getPartitions(spark, sourceTableIdent, partitionFilter); + if (sourceTablePartitions.isEmpty()) { + targetTable.newAppend().commit(); + } else { + importSparkPartitions( + spark, + sourceTablePartitions, + targetTable, + spec, + stagingDir, + checkDuplicateFiles, + service); + } + } + } catch (AnalysisException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unable to get partition spec for table: %s", sourceTableIdentWithDB); + } + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *

The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + boolean checkDuplicateFiles) { + importSparkTable( + spark, + sourceTableIdent, + targetTable, + stagingDir, + Collections.emptyMap(), + checkDuplicateFiles, + 1); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *

The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + */ + public static void importSparkTable( + SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, String stagingDir) { + importSparkTable( + spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false, 1); + } + + private static void importUnpartitionedSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + boolean checkDuplicateFiles, + ExecutorService service) { + try { + CatalogTable sourceTable = spark.sessionState().catalog().getTableMetadata(sourceTableIdent); + Option format = + sourceTable.storage().serde().nonEmpty() + ? sourceTable.storage().serde() + : sourceTable.provider(); + Preconditions.checkArgument(format.nonEmpty(), "Could not determine table format"); + + Map partition = Collections.emptyMap(); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Configuration conf = spark.sessionState().newHadoopConf(); + MetricsConfig metricsConfig = MetricsConfig.forTable(targetTable); + String nameMappingString = targetTable.properties().get(TableProperties.DEFAULT_NAME_MAPPING); + NameMapping nameMapping = + nameMappingString != null ? NameMappingParser.fromJson(nameMappingString) : null; + + List files = + TableMigrationUtil.listPartition( + partition, + Util.uriToString(sourceTable.location()), + format.get(), + spec, + conf, + metricsConfig, + nameMapping, + service); + + if (checkDuplicateFiles) { + Dataset importedFiles = + spark + .createDataset(Lists.transform(files, f -> f.path().toString()), Encoders.STRING()) + .toDF("file_path"); + Dataset existingFiles = + loadMetadataTable(spark, targetTable, MetadataTableType.ENTRIES).filter("status != 2"); + Column joinCond = + existingFiles.col("data_file.file_path").equalTo(importedFiles.col("file_path")); + Dataset duplicates = + importedFiles.join(existingFiles, joinCond).select("file_path").as(Encoders.STRING()); + Preconditions.checkState( + duplicates.isEmpty(), + String.format( + DUPLICATE_FILE_MESSAGE, Joiner.on(",").join((String[]) duplicates.take(10)))); + } + + AppendFiles append = targetTable.newAppend(); + files.forEach(append::appendFile); + append.commit(); + } catch (NoSuchDatabaseException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. Database not found in catalog.", sourceTableIdent); + } catch (NoSuchTableException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. Table not found in catalog.", sourceTableIdent); + } + } + + /** + * Import files from given partitions to an Iceberg table. 
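A hypothetical migration sketch (not part of the patch): snapshot the files of an existing session-catalog table into an Iceberg table with importSparkTable. The catalog type, warehouse path, table names, and staging directory are illustrative, and the target Iceberg table is assumed to already exist with a compatible schema and partition spec.

```java
import org.apache.iceberg.Table;
import org.apache.iceberg.hadoop.HadoopCatalog;
import org.apache.iceberg.spark.SparkTableUtil;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.TableIdentifier;
import scala.Some;

public class ImportSparkTableExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();

    HadoopCatalog catalog =
        new HadoopCatalog(spark.sessionState().newHadoopConf(), "file:/tmp/warehouse");
    Table target =
        catalog.loadTable(org.apache.iceberg.catalog.TableIdentifier.of("db", "logs_iceberg"));

    // Spark's TableIdentifier names the source table in the session catalog
    TableIdentifier source = new TableIdentifier("logs", Some.apply("db"));

    // lists the source data files, writes temporary manifests under the staging
    // directory, and commits a single append to the Iceberg table; overloads with a
    // parallelism or ExecutorService argument fan the file listing out over threads
    SparkTableUtil.importSparkTable(spark, source, target, "/tmp/iceberg-staging");

    spark.stop();
  }
}
```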
+ * + * @param spark a Spark session + * @param partitions partitions to import + * @param targetTable an Iceberg table where to import the data + * @param spec a partition spec + * @param stagingDir a staging directory to store temporary manifest files + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + */ + public static void importSparkPartitions( + SparkSession spark, + List partitions, + Table targetTable, + PartitionSpec spec, + String stagingDir, + boolean checkDuplicateFiles) { + importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, checkDuplicateFiles, 1); + } + + /** + * Import files from given partitions to an Iceberg table. + * + * @param spark a Spark session + * @param partitions partitions to import + * @param targetTable an Iceberg table where to import the data + * @param spec a partition spec + * @param stagingDir a staging directory to store temporary manifest files + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param parallelism number of threads to use for file reading + */ + public static void importSparkPartitions( + SparkSession spark, + List partitions, + Table targetTable, + PartitionSpec spec, + String stagingDir, + boolean checkDuplicateFiles, + int parallelism) { + importSparkPartitions( + spark, + partitions, + targetTable, + spec, + stagingDir, + checkDuplicateFiles, + TableMigrationUtil.migrationService(parallelism)); + } + + /** + * Import files from given partitions to an Iceberg table. + * + * @param spark a Spark session + * @param partitions partitions to import + * @param targetTable an Iceberg table where to import the data + * @param spec a partition spec + * @param stagingDir a staging directory to store temporary manifest files + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param service executor service to use for file reading. If null, file reading will be + * performed on the current thread. If non-null, the provided ExecutorService will be shutdown + * within this method after file reading is complete. + */ + public static void importSparkPartitions( + SparkSession spark, + List partitions, + Table targetTable, + PartitionSpec spec, + String stagingDir, + boolean checkDuplicateFiles, + ExecutorService service) { + Configuration conf = spark.sessionState().newHadoopConf(); + SerializableConfiguration serializableConf = new SerializableConfiguration(conf); + int listingParallelism = + Math.min( + partitions.size(), spark.sessionState().conf().parallelPartitionDiscoveryParallelism()); + int numShufflePartitions = spark.sessionState().conf().numShufflePartitions(); + MetricsConfig metricsConfig = MetricsConfig.fromProperties(targetTable.properties()); + String nameMappingString = targetTable.properties().get(TableProperties.DEFAULT_NAME_MAPPING); + NameMapping nameMapping = + nameMappingString != null ? 
NameMappingParser.fromJson(nameMappingString) : null; + + JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + JavaRDD partitionRDD = sparkContext.parallelize(partitions, listingParallelism); + + Dataset partitionDS = + spark.createDataset(partitionRDD.rdd(), Encoders.javaSerialization(SparkPartition.class)); + + Dataset filesToImport = + partitionDS.flatMap( + (FlatMapFunction) + sparkPartition -> + listPartition( + sparkPartition, + spec, + serializableConf, + metricsConfig, + nameMapping, + service) + .iterator(), + Encoders.javaSerialization(DataFile.class)); + + if (checkDuplicateFiles) { + Dataset importedFiles = + filesToImport + .map((MapFunction) f -> f.path().toString(), Encoders.STRING()) + .toDF("file_path"); + Dataset existingFiles = + loadMetadataTable(spark, targetTable, MetadataTableType.ENTRIES).filter("status != 2"); + Column joinCond = + existingFiles.col("data_file.file_path").equalTo(importedFiles.col("file_path")); + Dataset duplicates = + importedFiles.join(existingFiles, joinCond).select("file_path").as(Encoders.STRING()); + Preconditions.checkState( + duplicates.isEmpty(), + String.format( + DUPLICATE_FILE_MESSAGE, Joiner.on(",").join((String[]) duplicates.take(10)))); + } + + List manifests = + filesToImport + .repartition(numShufflePartitions) + .map( + (MapFunction>) + file -> Tuple2.apply(file.path().toString(), file), + Encoders.tuple(Encoders.STRING(), Encoders.javaSerialization(DataFile.class))) + .orderBy(col("_1")) + .mapPartitions( + (MapPartitionsFunction, ManifestFile>) + fileTuple -> buildManifest(serializableConf, spec, stagingDir, fileTuple), + Encoders.javaSerialization(ManifestFile.class)) + .collectAsList(); + + try { + TableOperations ops = ((HasTableOperations) targetTable).operations(); + int formatVersion = ops.current().formatVersion(); + boolean snapshotIdInheritanceEnabled = + PropertyUtil.propertyAsBoolean( + targetTable.properties(), + TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, + TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED_DEFAULT); + + AppendFiles append = targetTable.newAppend(); + manifests.forEach(append::appendManifest); + append.commit(); + + if (formatVersion == 1 && !snapshotIdInheritanceEnabled) { + // delete original manifests as they were rewritten before the commit + deleteManifests(targetTable.io(), manifests); + } + } catch (Throwable e) { + deleteManifests(targetTable.io(), manifests); + throw e; + } + } + + /** + * Import files from given partitions to an Iceberg table. 
+ * + * @param spark a Spark session + * @param partitions partitions to import + * @param targetTable an Iceberg table where to import the data + * @param spec a partition spec + * @param stagingDir a staging directory to store temporary manifest files + */ + public static void importSparkPartitions( + SparkSession spark, + List partitions, + Table targetTable, + PartitionSpec spec, + String stagingDir) { + importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false, 1); + } + + public static List filterPartitions( + List partitions, Map partitionFilter) { + if (partitionFilter.isEmpty()) { + return partitions; + } else { + return partitions.stream() + .filter(p -> p.getValues().entrySet().containsAll(partitionFilter.entrySet())) + .collect(Collectors.toList()); + } + } + + private static void deleteManifests(FileIO io, List manifests) { + Tasks.foreach(manifests) + .executeWith(ThreadPools.getWorkerPool()) + .noRetry() + .suppressFailureWhenFinished() + .run(item -> io.deleteFile(item.path())); + } + + public static Dataset loadTable(SparkSession spark, Table table, long snapshotId) { + SparkTable sparkTable = new SparkTable(table, snapshotId, false); + DataSourceV2Relation relation = createRelation(sparkTable, ImmutableMap.of()); + return Dataset.ofRows(spark, relation); + } + + public static Dataset loadMetadataTable( + SparkSession spark, Table table, MetadataTableType type) { + return loadMetadataTable(spark, table, type, ImmutableMap.of()); + } + + public static Dataset loadMetadataTable( + SparkSession spark, Table table, MetadataTableType type, Map extraOptions) { + Table metadataTable = MetadataTableUtils.createMetadataTableInstance(table, type); + SparkTable sparkMetadataTable = new SparkTable(metadataTable, false); + DataSourceV2Relation relation = createRelation(sparkMetadataTable, extraOptions); + return Dataset.ofRows(spark, relation); + } + + private static DataSourceV2Relation createRelation( + SparkTable sparkTable, Map extraOptions) { + CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(extraOptions); + return DataSourceV2Relation.create(sparkTable, Option.empty(), Option.empty(), options); + } + + /** + * Determine the write branch. + * + *

Validate wap config and determine the write branch. + * + * @param spark a Spark Session + * @param branch write branch if there is no WAP branch configured + * @return branch for write operation + */ + public static String determineWriteBranch(SparkSession spark, String branch) { + String wapId = spark.conf().get(SparkSQLProperties.WAP_ID, null); + String wapBranch = spark.conf().get(SparkSQLProperties.WAP_BRANCH, null); + ValidationException.check( + wapId == null || wapBranch == null, + "Cannot set both WAP ID and branch, but got ID [%s] and branch [%s]", + wapId, + wapBranch); + + if (wapBranch != null) { + ValidationException.check( + branch == null, + "Cannot write to both branch and WAP branch, but got branch [%s] and WAP branch [%s]", + branch, + wapBranch); + + return wapBranch; + } + return branch; + } + + public static boolean wapEnabled(Table table) { + return PropertyUtil.propertyAsBoolean( + table.properties(), + TableProperties.WRITE_AUDIT_PUBLISH_ENABLED, + Boolean.parseBoolean(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED_DEFAULT)); + } + + /** Class representing a table partition. */ + public static class SparkPartition implements Serializable { + private final Map values; + private final String uri; + private final String format; + + public SparkPartition(Map values, String uri, String format) { + this.values = Maps.newHashMap(values); + this.uri = uri; + this.format = format; + } + + public Map getValues() { + return values; + } + + public String getUri() { + return uri; + } + + public String getFormat() { + return format; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("values", values) + .add("uri", uri) + .add("format", format) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SparkPartition that = (SparkPartition) o; + return Objects.equal(values, that.values) + && Objects.equal(uri, that.uri) + && Objects.equal(format, that.format); + } + + @Override + public int hashCode() { + return Objects.hashCode(values, uri, format); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkTypeToType.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkTypeToType.java new file mode 100644 index 000000000000..8beaefc5cc8f --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkTypeToType.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.CharType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampNTZType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.types.VarcharType; + +class SparkTypeToType extends SparkTypeVisitor { + private final StructType root; + private int nextId = 0; + + SparkTypeToType() { + this.root = null; + } + + SparkTypeToType(StructType root) { + this.root = root; + // the root struct's fields use the first ids + this.nextId = root.fields().length; + } + + private int getNextId() { + int next = nextId; + nextId += 1; + return next; + } + + @Override + @SuppressWarnings("ReferenceEquality") + public Type struct(StructType struct, List types) { + StructField[] fields = struct.fields(); + List newFields = Lists.newArrayListWithExpectedSize(fields.length); + boolean isRoot = root == struct; + for (int i = 0; i < fields.length; i += 1) { + StructField field = fields[i]; + Type type = types.get(i); + + int id; + if (isRoot) { + // for new conversions, use ordinals for ids in the root struct + id = i; + } else { + id = getNextId(); + } + + String doc = field.getComment().isDefined() ? 
field.getComment().get() : null; + + if (field.nullable()) { + newFields.add(Types.NestedField.optional(id, field.name(), type, doc)); + } else { + newFields.add(Types.NestedField.required(id, field.name(), type, doc)); + } + } + + return Types.StructType.of(newFields); + } + + @Override + public Type field(StructField field, Type typeResult) { + return typeResult; + } + + @Override + public Type array(ArrayType array, Type elementType) { + if (array.containsNull()) { + return Types.ListType.ofOptional(getNextId(), elementType); + } else { + return Types.ListType.ofRequired(getNextId(), elementType); + } + } + + @Override + public Type map(MapType map, Type keyType, Type valueType) { + if (map.valueContainsNull()) { + return Types.MapType.ofOptional(getNextId(), getNextId(), keyType, valueType); + } else { + return Types.MapType.ofRequired(getNextId(), getNextId(), keyType, valueType); + } + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + @Override + public Type atomic(DataType atomic) { + if (atomic instanceof BooleanType) { + return Types.BooleanType.get(); + + } else if (atomic instanceof IntegerType + || atomic instanceof ShortType + || atomic instanceof ByteType) { + return Types.IntegerType.get(); + + } else if (atomic instanceof LongType) { + return Types.LongType.get(); + + } else if (atomic instanceof FloatType) { + return Types.FloatType.get(); + + } else if (atomic instanceof DoubleType) { + return Types.DoubleType.get(); + + } else if (atomic instanceof StringType + || atomic instanceof CharType + || atomic instanceof VarcharType) { + return Types.StringType.get(); + + } else if (atomic instanceof DateType) { + return Types.DateType.get(); + + } else if (atomic instanceof TimestampType) { + return Types.TimestampType.withZone(); + + } else if (atomic instanceof TimestampNTZType) { + return Types.TimestampType.withoutZone(); + + } else if (atomic instanceof DecimalType) { + return Types.DecimalType.of( + ((DecimalType) atomic).precision(), ((DecimalType) atomic).scale()); + } else if (atomic instanceof BinaryType) { + return Types.BinaryType.get(); + } + + throw new UnsupportedOperationException("Not a supported type: " + atomic.catalogString()); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkTypeVisitor.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkTypeVisitor.java new file mode 100644 index 000000000000..1ef694263fa4 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkTypeVisitor.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.UserDefinedType; + +class SparkTypeVisitor { + static T visit(DataType type, SparkTypeVisitor visitor) { + if (type instanceof StructType) { + StructField[] fields = ((StructType) type).fields(); + List fieldResults = Lists.newArrayListWithExpectedSize(fields.length); + + for (StructField field : fields) { + fieldResults.add(visitor.field(field, visit(field.dataType(), visitor))); + } + + return visitor.struct((StructType) type, fieldResults); + + } else if (type instanceof MapType) { + return visitor.map( + (MapType) type, + visit(((MapType) type).keyType(), visitor), + visit(((MapType) type).valueType(), visitor)); + + } else if (type instanceof ArrayType) { + return visitor.array((ArrayType) type, visit(((ArrayType) type).elementType(), visitor)); + + } else if (type instanceof UserDefinedType) { + throw new UnsupportedOperationException("User-defined types are not supported"); + + } else { + return visitor.atomic(type); + } + } + + public T struct(StructType struct, List fieldResults) { + return null; + } + + public T field(StructField field, T typeResult) { + return null; + } + + public T array(ArrayType array, T elementResult) { + return null; + } + + public T map(MapType map, T keyResult, T valueResult) { + return null; + } + + public T atomic(DataType atomic) { + return null; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java new file mode 100644 index 000000000000..de06cceb2677 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.transforms.Transform; +import org.apache.iceberg.transforms.UnknownTransform; +import org.apache.iceberg.util.Pair; +import org.apache.spark.SparkEnv; +import org.apache.spark.scheduler.ExecutorCacheTaskLocation; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.expressions.BoundReference; +import org.apache.spark.sql.catalyst.expressions.EqualTo; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockManagerId; +import org.apache.spark.storage.BlockManagerMaster; +import org.joda.time.DateTime; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +public class SparkUtil { + private static final String SPARK_CATALOG_CONF_PREFIX = "spark.sql.catalog"; + // Format string used as the prefix for Spark configuration keys to override Hadoop configuration + // values for Iceberg tables from a given catalog. These keys can be specified as + // `spark.sql.catalog.$catalogName.hadoop.*`, similar to using `spark.hadoop.*` to override + // Hadoop configurations globally for a given Spark session. + private static final String SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR = + SPARK_CATALOG_CONF_PREFIX + ".%s.hadoop."; + + private static final Joiner DOT = Joiner.on("."); + + private SparkUtil() {} + + /** + * Check whether the partition transforms in a spec can be used to write data. 
+ * + * @param spec a PartitionSpec + * @throws UnsupportedOperationException if the spec contains unknown partition transforms + */ + public static void validatePartitionTransforms(PartitionSpec spec) { + if (spec.fields().stream().anyMatch(field -> field.transform() instanceof UnknownTransform)) { + String unsupported = + spec.fields().stream() + .map(PartitionField::transform) + .filter(transform -> transform instanceof UnknownTransform) + .map(Transform::toString) + .collect(Collectors.joining(", ")); + + throw new UnsupportedOperationException( + String.format("Cannot write using unsupported transforms: %s", unsupported)); + } + } + + /** + * A modified version of Spark's LookupCatalog.CatalogAndIdentifier.unapply Attempts to find the + * catalog and identifier a multipart identifier represents + * + * @param nameParts Multipart identifier representing a table + * @return The CatalogPlugin and Identifier for the table + */ + public static Pair catalogAndIdentifier( + List nameParts, + Function catalogProvider, + BiFunction identiferProvider, + C currentCatalog, + String[] currentNamespace) { + Preconditions.checkArgument( + !nameParts.isEmpty(), "Cannot determine catalog and identifier from empty name"); + + int lastElementIndex = nameParts.size() - 1; + String name = nameParts.get(lastElementIndex); + + if (nameParts.size() == 1) { + // Only a single element, use current catalog and namespace + return Pair.of(currentCatalog, identiferProvider.apply(currentNamespace, name)); + } else { + C catalog = catalogProvider.apply(nameParts.get(0)); + if (catalog == null) { + // The first element was not a valid catalog, treat it like part of the namespace + String[] namespace = nameParts.subList(0, lastElementIndex).toArray(new String[0]); + return Pair.of(currentCatalog, identiferProvider.apply(namespace, name)); + } else { + // Assume the first element is a valid catalog + String[] namespace = nameParts.subList(1, lastElementIndex).toArray(new String[0]); + return Pair.of(catalog, identiferProvider.apply(namespace, name)); + } + } + } + + /** + * Pulls any Catalog specific overrides for the Hadoop conf from the current SparkSession, which + * can be set via `spark.sql.catalog.$catalogName.hadoop.*` + * + *

Mirrors the override of Hadoop configurations for a given Spark session using + * `spark.hadoop.*`. + * + *

The SparkCatalog allows for hadoop configurations to be overridden per catalog, by setting + * them on the SQLConf, where the following will add the property "fs.default.name" with value + * "hdfs://hanksnamenode:8020" to the catalog's hadoop configuration. SparkSession.builder() + * .config(s"spark.sql.catalog.$catalogName.hadoop.fs.default.name", "hdfs://hanksnamenode:8020") + * .getOrCreate() + * + * @param spark The current Spark session + * @param catalogName Name of the catalog to find overrides for. + * @return the Hadoop Configuration that should be used for this catalog, with catalog specific + * overrides applied. + */ + public static Configuration hadoopConfCatalogOverrides(SparkSession spark, String catalogName) { + // Find keys for the catalog intended to be hadoop configurations + final String hadoopConfCatalogPrefix = hadoopConfPrefixForCatalog(catalogName); + final Configuration conf = spark.sessionState().newHadoopConf(); + spark + .sqlContext() + .conf() + .settings() + .forEach( + (k, v) -> { + // these checks are copied from `spark.sessionState().newHadoopConfWithOptions()` + // to avoid converting back and forth between Scala / Java map types + if (v != null && k != null && k.startsWith(hadoopConfCatalogPrefix)) { + conf.set(k.substring(hadoopConfCatalogPrefix.length()), v); + } + }); + return conf; + } + + private static String hadoopConfPrefixForCatalog(String catalogName) { + return String.format(SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR, catalogName); + } + + /** + * Get a List of Spark filter Expression. + * + * @param schema table schema + * @param filters filters in the format of a Map, where key is one of the table column name, and + * value is the specific value to be filtered on the column. + * @return a List of filters in the format of Spark Expression. 
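As a rough illustration of what this helper produces (the schema, column names, and values below are made up), a partition-value map keyed by column name becomes a list of catalyst EqualTo expressions:

import java.util.List;
import java.util.Map;
import org.apache.iceberg.spark.SparkUtil;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class PartitionFilterExample {
  public static void main(String[] args) {
    StructType schema =
        new StructType()
            .add("dt", DataTypes.StringType)
            .add("hour", DataTypes.IntegerType);

    // One entry per partition column to filter on; values are always given as strings.
    Map<String, String> filter = Map.of("dt", "2024-01-01", "hour", "12");

    // Produces EqualTo(dt, '2024-01-01') and EqualTo(hour, 12); entries that do not
    // match a column in the schema are silently ignored.
    List<Expression> expressions = SparkUtil.partitionMapToExpression(schema, filter);
    expressions.forEach(expr -> System.out.println(expr));
  }
}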
+ */ + public static List partitionMapToExpression( + StructType schema, Map filters) { + List filterExpressions = Lists.newArrayList(); + for (Map.Entry entry : filters.entrySet()) { + try { + int index = schema.fieldIndex(entry.getKey()); + DataType dataType = schema.fields()[index].dataType(); + BoundReference ref = new BoundReference(index, dataType, true); + switch (dataType.typeName()) { + case "integer": + filterExpressions.add( + new EqualTo( + ref, + Literal.create(Integer.parseInt(entry.getValue()), DataTypes.IntegerType))); + break; + case "string": + filterExpressions.add( + new EqualTo(ref, Literal.create(entry.getValue(), DataTypes.StringType))); + break; + case "short": + filterExpressions.add( + new EqualTo( + ref, Literal.create(Short.parseShort(entry.getValue()), DataTypes.ShortType))); + break; + case "long": + filterExpressions.add( + new EqualTo( + ref, Literal.create(Long.parseLong(entry.getValue()), DataTypes.LongType))); + break; + case "float": + filterExpressions.add( + new EqualTo( + ref, Literal.create(Float.parseFloat(entry.getValue()), DataTypes.FloatType))); + break; + case "double": + filterExpressions.add( + new EqualTo( + ref, + Literal.create(Double.parseDouble(entry.getValue()), DataTypes.DoubleType))); + break; + case "date": + filterExpressions.add( + new EqualTo( + ref, + Literal.create( + new Date(DateTime.parse(entry.getValue()).getMillis()), + DataTypes.DateType))); + break; + case "timestamp": + filterExpressions.add( + new EqualTo( + ref, + Literal.create( + new Timestamp(DateTime.parse(entry.getValue()).getMillis()), + DataTypes.TimestampType))); + break; + default: + throw new IllegalStateException( + "Unexpected data type in partition filters: " + dataType); + } + } catch (IllegalArgumentException e) { + // ignore if filter is not on table columns + } + } + + return filterExpressions; + } + + public static String toColumnName(NamedReference ref) { + return DOT.join(ref.fieldNames()); + } + + public static boolean caseSensitive(SparkSession spark) { + return Boolean.parseBoolean(spark.conf().get("spark.sql.caseSensitive")); + } + + public static List executorLocations() { + BlockManager driverBlockManager = SparkEnv.get().blockManager(); + List executorBlockManagerIds = fetchPeers(driverBlockManager); + return executorBlockManagerIds.stream() + .map(SparkUtil::toExecutorLocation) + .sorted() + .collect(Collectors.toList()); + } + + private static List fetchPeers(BlockManager blockManager) { + BlockManagerMaster master = blockManager.master(); + BlockManagerId id = blockManager.blockManagerId(); + return toJavaList(master.getPeers(id)); + } + + private static List toJavaList(Seq seq) { + return JavaConverters.seqAsJavaListConverter(seq).asJava(); + } + + private static String toExecutorLocation(BlockManagerId id) { + return ExecutorCacheTaskLocation.apply(id.host(), id.executorId()).toString(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java new file mode 100644 index 000000000000..57b9d61e38bd --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java @@ -0,0 +1,484 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.expressions.Expressions.and; +import static org.apache.iceberg.expressions.Expressions.bucket; +import static org.apache.iceberg.expressions.Expressions.day; +import static org.apache.iceberg.expressions.Expressions.equal; +import static org.apache.iceberg.expressions.Expressions.greaterThan; +import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.hour; +import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; +import static org.apache.iceberg.expressions.Expressions.isNull; +import static org.apache.iceberg.expressions.Expressions.lessThan; +import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.month; +import static org.apache.iceberg.expressions.Expressions.not; +import static org.apache.iceberg.expressions.Expressions.notEqual; +import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNaN; +import static org.apache.iceberg.expressions.Expressions.notNull; +import static org.apache.iceberg.expressions.Expressions.or; +import static org.apache.iceberg.expressions.Expressions.startsWith; +import static org.apache.iceberg.expressions.Expressions.truncate; +import static org.apache.iceberg.expressions.Expressions.year; + +import java.util.Arrays; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expression.Operation; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.expressions.UnboundPredicate; +import org.apache.iceberg.expressions.UnboundTerm; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.util.NaNUtil; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc; +import org.apache.spark.sql.connector.expressions.filter.And; +import org.apache.spark.sql.connector.expressions.filter.Not; +import org.apache.spark.sql.connector.expressions.filter.Or; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkV2Filters { + + public static final Set SUPPORTED_FUNCTIONS = + ImmutableSet.of("years", "months", "days", "hours", "bucket", "truncate"); 
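For orientation, a hedged sketch of the conversion this class performs. The predicate is built by hand the way Spark's DSv2 pushdown would supply it; the column name is illustrative, and the Spark 4.0 connector expression classes (Expressions, LiteralValue, Predicate) are assumed to keep their 3.x shapes.

import org.apache.iceberg.spark.SparkV2Filters;
import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.LiteralValue;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.types.DataTypes;

public class V2FilterExample {
  public static void main(String[] args) {
    NamedReference idCol = Expressions.column("id");

    // Equivalent of the pushed-down Spark predicate `id > 5`.
    Predicate idGreaterThan5 =
        new Predicate(
            ">",
            new org.apache.spark.sql.connector.expressions.Expression[] {
              idCol, new LiteralValue<>(5, DataTypes.IntegerType)
            });

    // Converts to Iceberg's greaterThan("id", 5); convert(...) returns null when a
    // predicate cannot be translated, so callers must be prepared for partial pushdown.
    org.apache.iceberg.expressions.Expression converted = SparkV2Filters.convert(idGreaterThan5);
    System.out.println(converted);
  }
}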
+ + private static final String TRUE = "ALWAYS_TRUE"; + private static final String FALSE = "ALWAYS_FALSE"; + private static final String EQ = "="; + private static final String EQ_NULL_SAFE = "<=>"; + private static final String NOT_EQ = "<>"; + private static final String GT = ">"; + private static final String GT_EQ = ">="; + private static final String LT = "<"; + private static final String LT_EQ = "<="; + private static final String IN = "IN"; + private static final String IS_NULL = "IS_NULL"; + private static final String NOT_NULL = "IS_NOT_NULL"; + private static final String AND = "AND"; + private static final String OR = "OR"; + private static final String NOT = "NOT"; + private static final String STARTS_WITH = "STARTS_WITH"; + + private static final Map FILTERS = + ImmutableMap.builder() + .put(TRUE, Operation.TRUE) + .put(FALSE, Operation.FALSE) + .put(EQ, Operation.EQ) + .put(EQ_NULL_SAFE, Operation.EQ) + .put(NOT_EQ, Operation.NOT_EQ) + .put(GT, Operation.GT) + .put(GT_EQ, Operation.GT_EQ) + .put(LT, Operation.LT) + .put(LT_EQ, Operation.LT_EQ) + .put(IN, Operation.IN) + .put(IS_NULL, Operation.IS_NULL) + .put(NOT_NULL, Operation.NOT_NULL) + .put(AND, Operation.AND) + .put(OR, Operation.OR) + .put(NOT, Operation.NOT) + .put(STARTS_WITH, Operation.STARTS_WITH) + .buildOrThrow(); + + private SparkV2Filters() {} + + public static Expression convert(Predicate[] predicates) { + Expression expression = Expressions.alwaysTrue(); + for (Predicate predicate : predicates) { + Expression converted = convert(predicate); + Preconditions.checkArgument( + converted != null, "Cannot convert Spark predicate to Iceberg expression: %s", predicate); + expression = Expressions.and(expression, converted); + } + + return expression; + } + + @SuppressWarnings({"checkstyle:CyclomaticComplexity", "checkstyle:MethodLength"}) + public static Expression convert(Predicate predicate) { + Operation op = FILTERS.get(predicate.name()); + if (op != null) { + switch (op) { + case TRUE: + return Expressions.alwaysTrue(); + + case FALSE: + return Expressions.alwaysFalse(); + + case IS_NULL: + if (canConvertToTerm(child(predicate))) { + UnboundTerm term = toTerm(child(predicate)); + return term != null ? isNull(term) : null; + } + + return null; + + case NOT_NULL: + if (canConvertToTerm(child(predicate))) { + UnboundTerm term = toTerm(child(predicate)); + return term != null ? notNull(term) : null; + } + + return null; + + case LT: + if (canConvertToTerm(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + UnboundTerm term = toTerm(leftChild(predicate)); + return term != null ? lessThan(term, convertLiteral(rightChild(predicate))) : null; + } else if (canConvertToTerm(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + UnboundTerm term = toTerm(rightChild(predicate)); + return term != null ? greaterThan(term, convertLiteral(leftChild(predicate))) : null; + } else { + return null; + } + + case LT_EQ: + if (canConvertToTerm(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + UnboundTerm term = toTerm(leftChild(predicate)); + return term != null + ? lessThanOrEqual(term, convertLiteral(rightChild(predicate))) + : null; + } else if (canConvertToTerm(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + UnboundTerm term = toTerm(rightChild(predicate)); + return term != null + ? 
greaterThanOrEqual(term, convertLiteral(leftChild(predicate))) + : null; + } else { + return null; + } + + case GT: + if (canConvertToTerm(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + UnboundTerm term = toTerm(leftChild(predicate)); + return term != null ? greaterThan(term, convertLiteral(rightChild(predicate))) : null; + } else if (canConvertToTerm(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + UnboundTerm term = toTerm(rightChild(predicate)); + return term != null ? lessThan(term, convertLiteral(leftChild(predicate))) : null; + } else { + return null; + } + + case GT_EQ: + if (canConvertToTerm(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + UnboundTerm term = toTerm(leftChild(predicate)); + return term != null + ? greaterThanOrEqual(term, convertLiteral(rightChild(predicate))) + : null; + } else if (canConvertToTerm(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + UnboundTerm term = toTerm(rightChild(predicate)); + return term != null + ? lessThanOrEqual(term, convertLiteral(leftChild(predicate))) + : null; + } else { + return null; + } + + case EQ: // used for both eq and null-safe-eq + Pair, Object> eqChildren = predicateChildren(predicate); + if (eqChildren == null) { + return null; + } + + if (predicate.name().equals(EQ)) { + // comparison with null in normal equality is always null. this is probably a mistake. + Preconditions.checkNotNull( + eqChildren.second(), + "Expression is always false (eq is not null-safe): %s", + predicate); + } + + return handleEqual(eqChildren.first(), eqChildren.second()); + + case NOT_EQ: + Pair, Object> notEqChildren = predicateChildren(predicate); + if (notEqChildren == null) { + return null; + } + + // comparison with null in normal equality is always null. this is probably a mistake. + Preconditions.checkNotNull( + notEqChildren.second(), + "Expression is always false (notEq is not null-safe): %s", + predicate); + + return handleNotEqual(notEqChildren.first(), notEqChildren.second()); + + case IN: + if (isSupportedInPredicate(predicate)) { + UnboundTerm term = toTerm(childAtIndex(predicate, 0)); + + return term != null + ? 
in( + term, + Arrays.stream(predicate.children()) + .skip(1) + .map(val -> convertLiteral(((Literal) val))) + .filter(Objects::nonNull) + .collect(Collectors.toList())) + : null; + } else { + return null; + } + + case NOT: + Not notPredicate = (Not) predicate; + Predicate childPredicate = notPredicate.child(); + if (childPredicate.name().equals(IN) && isSupportedInPredicate(childPredicate)) { + UnboundTerm term = toTerm(childAtIndex(childPredicate, 0)); + if (term == null) { + return null; + } + + // infer an extra notNull predicate for Spark NOT IN filters + // as Iceberg expressions don't follow the 3-value SQL boolean logic + // col NOT IN (1, 2) in Spark is equal to notNull(col) && notIn(col, 1, 2) in Iceberg + Expression notIn = + notIn( + term, + Arrays.stream(childPredicate.children()) + .skip(1) + .map(val -> convertLiteral(((Literal) val))) + .filter(Objects::nonNull) + .collect(Collectors.toList())); + return and(notNull(term), notIn); + } else if (hasNoInFilter(childPredicate)) { + Expression child = convert(childPredicate); + if (child != null) { + return not(child); + } + } + return null; + + case AND: + { + And andPredicate = (And) predicate; + Expression left = convert(andPredicate.left()); + Expression right = convert(andPredicate.right()); + if (left != null && right != null) { + return and(left, right); + } + return null; + } + + case OR: + { + Or orPredicate = (Or) predicate; + Expression left = convert(orPredicate.left()); + Expression right = convert(orPredicate.right()); + if (left != null && right != null) { + return or(left, right); + } + return null; + } + + case STARTS_WITH: + String colName = SparkUtil.toColumnName(leftChild(predicate)); + return startsWith(colName, convertLiteral(rightChild(predicate)).toString()); + } + } + + return null; + } + + private static Pair, Object> predicateChildren(Predicate predicate) { + if (canConvertToTerm(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + UnboundTerm term = toTerm(leftChild(predicate)); + return term != null ? Pair.of(term, convertLiteral(rightChild(predicate))) : null; + + } else if (canConvertToTerm(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + UnboundTerm term = toTerm(rightChild(predicate)); + return term != null ? 
Pair.of(term, convertLiteral(leftChild(predicate))) : null; + + } else { + return null; + } + } + + @SuppressWarnings("unchecked") + private static T child(Predicate predicate) { + org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children(); + Preconditions.checkArgument( + children.length == 1, "Predicate should have one child: %s", predicate); + return (T) children[0]; + } + + @SuppressWarnings("unchecked") + private static T leftChild(Predicate predicate) { + org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children(); + Preconditions.checkArgument( + children.length == 2, "Predicate should have two children: %s", predicate); + return (T) children[0]; + } + + @SuppressWarnings("unchecked") + private static T rightChild(Predicate predicate) { + org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children(); + Preconditions.checkArgument( + children.length == 2, "Predicate should have two children: %s", predicate); + return (T) children[1]; + } + + @SuppressWarnings("unchecked") + private static T childAtIndex(Predicate predicate, int index) { + return (T) predicate.children()[index]; + } + + private static boolean canConvertToTerm( + org.apache.spark.sql.connector.expressions.Expression expr) { + return isRef(expr) || isSystemFunc(expr); + } + + private static boolean isRef(org.apache.spark.sql.connector.expressions.Expression expr) { + return expr instanceof NamedReference; + } + + private static boolean isSystemFunc(org.apache.spark.sql.connector.expressions.Expression expr) { + if (expr instanceof UserDefinedScalarFunc) { + UserDefinedScalarFunc udf = (UserDefinedScalarFunc) expr; + return udf.canonicalName().startsWith("iceberg") + && SUPPORTED_FUNCTIONS.contains(udf.name()) + && Arrays.stream(udf.children()).allMatch(child -> isLiteral(child) || isRef(child)); + } + + return false; + } + + private static boolean isLiteral(org.apache.spark.sql.connector.expressions.Expression expr) { + return expr instanceof Literal; + } + + private static Object convertLiteral(Literal literal) { + if (literal.value() instanceof UTF8String) { + return ((UTF8String) literal.value()).toString(); + } else if (literal.value() instanceof Decimal) { + return ((Decimal) literal.value()).toJavaBigDecimal(); + } + return literal.value(); + } + + private static UnboundPredicate handleEqual(UnboundTerm term, Object value) { + if (value == null) { + return isNull(term); + } else if (NaNUtil.isNaN(value)) { + return isNaN(term); + } else { + return equal(term, value); + } + } + + private static UnboundPredicate handleNotEqual(UnboundTerm term, Object value) { + if (NaNUtil.isNaN(value)) { + return notNaN(term); + } else { + return notEqual(term, value); + } + } + + private static boolean hasNoInFilter(Predicate predicate) { + Operation op = FILTERS.get(predicate.name()); + + if (op != null) { + switch (op) { + case AND: + And andPredicate = (And) predicate; + return hasNoInFilter(andPredicate.left()) && hasNoInFilter(andPredicate.right()); + case OR: + Or orPredicate = (Or) predicate; + return hasNoInFilter(orPredicate.left()) && hasNoInFilter(orPredicate.right()); + case NOT: + Not notPredicate = (Not) predicate; + return hasNoInFilter(notPredicate.child()); + case IN: + return false; + default: + return true; + } + } + + return false; + } + + private static boolean isSupportedInPredicate(Predicate predicate) { + if (!canConvertToTerm(childAtIndex(predicate, 0))) { + return false; + } else { + return 
Arrays.stream(predicate.children()).skip(1).allMatch(SparkV2Filters::isLiteral); + } + } + + /** Should be called after {@link #canConvertToTerm} passed */ + private static UnboundTerm toTerm(T input) { + if (input instanceof NamedReference) { + return Expressions.ref(SparkUtil.toColumnName((NamedReference) input)); + } else if (input instanceof UserDefinedScalarFunc) { + return udfToTerm((UserDefinedScalarFunc) input); + } else { + return null; + } + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + private static UnboundTerm udfToTerm(UserDefinedScalarFunc udf) { + org.apache.spark.sql.connector.expressions.Expression[] children = udf.children(); + String udfName = udf.name().toLowerCase(Locale.ROOT); + if (children.length == 1) { + org.apache.spark.sql.connector.expressions.Expression child = children[0]; + if (isRef(child)) { + String column = SparkUtil.toColumnName((NamedReference) child); + switch (udfName) { + case "years": + return year(column); + case "months": + return month(column); + case "days": + return day(column); + case "hours": + return hour(column); + } + } + } else if (children.length == 2) { + if (isLiteral(children[0]) && isRef(children[1])) { + String column = SparkUtil.toColumnName((NamedReference) children[1]); + switch (udfName) { + case "bucket": + int numBuckets = (Integer) convertLiteral((Literal) children[0]); + return bucket(column, numBuckets); + case "truncate": + int width = (Integer) convertLiteral((Literal) children[0]); + return truncate(column, width); + } + } + } + + return null; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java new file mode 100644 index 000000000000..28b717ac090e --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; + +/** A utility class that converts Spark values to Iceberg's internal representation. 
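For example (schema and values invented for illustration), a Spark Row is converted field by field into an Iceberg generic Record:

import org.apache.iceberg.Schema;
import org.apache.iceberg.data.Record;
import org.apache.iceberg.spark.SparkValueConverter;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;

public class ValueConverterExample {
  public static void main(String[] args) {
    Schema schema =
        new Schema(
            Types.NestedField.required(1, "id", Types.LongType.get()),
            Types.NestedField.optional(2, "name", Types.StringType.get()));

    // Positional Row matching the schema above.
    Row row = RowFactory.create(1L, "alice");

    // Produces a GenericRecord with id=1 and name="alice".
    Record record = SparkValueConverter.convert(schema, row);
    System.out.println(record);
  }
}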
*/ +public class SparkValueConverter { + + private SparkValueConverter() {} + + public static Record convert(Schema schema, Row row) { + return convert(schema.asStruct(), row); + } + + public static Object convert(Type type, Object object) { + if (object == null) { + return null; + } + + switch (type.typeId()) { + case STRUCT: + return convert(type.asStructType(), (Row) object); + + case LIST: + List convertedList = Lists.newArrayList(); + List list = (List) object; + for (Object element : list) { + convertedList.add(convert(type.asListType().elementType(), element)); + } + return convertedList; + + case MAP: + Map convertedMap = Maps.newLinkedHashMap(); + Map map = (Map) object; + for (Map.Entry entry : map.entrySet()) { + convertedMap.put( + convert(type.asMapType().keyType(), entry.getKey()), + convert(type.asMapType().valueType(), entry.getValue())); + } + return convertedMap; + + case DATE: + // if spark.sql.datetime.java8API.enabled is set to true, java.time.LocalDate + // for Spark SQL DATE type otherwise java.sql.Date is returned. + return DateTimeUtils.anyToDays(object); + case TIMESTAMP: + return DateTimeUtils.anyToMicros(object); + case BINARY: + return ByteBuffer.wrap((byte[]) object); + case INTEGER: + return ((Number) object).intValue(); + case BOOLEAN: + case LONG: + case FLOAT: + case DOUBLE: + case DECIMAL: + case STRING: + case FIXED: + return object; + default: + throw new UnsupportedOperationException("Not a supported type: " + type); + } + } + + private static Record convert(Types.StructType struct, Row row) { + if (row == null) { + return null; + } + + Record record = GenericRecord.create(struct); + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + + Type fieldType = field.type(); + + switch (fieldType.typeId()) { + case STRUCT: + record.set(i, convert(fieldType.asStructType(), row.getStruct(i))); + break; + case LIST: + record.set(i, convert(fieldType.asListType(), row.getList(i))); + break; + case MAP: + record.set(i, convert(fieldType.asMapType(), row.getJavaMap(i))); + break; + default: + record.set(i, convert(fieldType, row.get(i))); + } + } + return record; + } + + public static Object convertToSpark(Type type, Object object) { + if (object == null) { + return null; + } + + switch (type.typeId()) { + case STRUCT: + case LIST: + case MAP: + return new UnsupportedOperationException("Complex types currently not supported"); + case DATE: + return DateTimeUtils.daysToLocalDate((int) object); + case TIMESTAMP: + Types.TimestampType ts = (Types.TimestampType) type.asPrimitiveType(); + if (ts.shouldAdjustToUTC()) { + return DateTimeUtils.microsToInstant((long) object); + } else { + return DateTimeUtils.microsToLocalDateTime((long) object); + } + case BINARY: + return ByteBuffers.toByteArray((ByteBuffer) object); + case INTEGER: + case BOOLEAN: + case LONG: + case FLOAT: + case DOUBLE: + case DECIMAL: + case STRING: + case FIXED: + return object; + default: + throw new UnsupportedOperationException("Not a supported type: " + type); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java new file mode 100644 index 000000000000..2c8c26d80977 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -0,0 +1,724 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.DistributionMode.HASH; +import static org.apache.iceberg.DistributionMode.NONE; +import static org.apache.iceberg.DistributionMode.RANGE; +import static org.apache.iceberg.TableProperties.AVRO_COMPRESSION; +import static org.apache.iceberg.TableProperties.AVRO_COMPRESSION_LEVEL; +import static org.apache.iceberg.TableProperties.DELETE_AVRO_COMPRESSION; +import static org.apache.iceberg.TableProperties.DELETE_AVRO_COMPRESSION_LEVEL; +import static org.apache.iceberg.TableProperties.DELETE_ORC_COMPRESSION; +import static org.apache.iceberg.TableProperties.DELETE_ORC_COMPRESSION_STRATEGY; +import static org.apache.iceberg.TableProperties.DELETE_PARQUET_COMPRESSION; +import static org.apache.iceberg.TableProperties.DELETE_PARQUET_COMPRESSION_LEVEL; +import static org.apache.iceberg.TableProperties.ORC_COMPRESSION; +import static org.apache.iceberg.TableProperties.ORC_COMPRESSION_STRATEGY; +import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION; +import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION_LEVEL; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; + +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.deletes.DeleteGranularity; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.RuntimeConfig; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.apache.spark.sql.internal.SQLConf; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A class for common Iceberg configs for Spark writes. + * + *

If a config is set at multiple levels, the following order of precedence is used (top to + * bottom): + * + *

+ * <ol>
+ *   <li>Write options
+ *   <li>Session configuration
+ *   <li>Table metadata
+ * </ol>
+ * + * The most specific value is set in write options and takes precedence over all other configs. If + * no write option is provided, this class checks the session configuration for any overrides. If no + * applicable value is found in the session configuration, this class uses the table metadata. + * + *
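A small sketch of this precedence in practice; the option key compression-codec, the session property spark.sql.iceberg.compression-codec, and the table property write.parquet.compression-codec are the keys resolved by this class, while the table identifier and codec names are illustrative:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class WriteConfPrecedenceExample {
  public static void main(String[] args) throws Exception {
    SparkSession spark = SparkSession.builder().getOrCreate();
    Dataset<Row> df = spark.range(10).toDF();

    // (2) Session-level default for all Iceberg writes in this session.
    spark.conf().set("spark.sql.iceberg.compression-codec", "gzip");

    // (1) Per-write option: wins over the session conf above and over the
    //     table property write.parquet.compression-codec (3), per the order described here.
    df.writeTo("db.events_iceberg") // illustrative identifier; the table must already exist
        .option("compression-codec", "zstd")
        .append();
  }
}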

Note this class is NOT meant to be serialized and sent to executors. + */ +public class SparkWriteConf { + + private static final Logger LOG = LoggerFactory.getLogger(SparkWriteConf.class); + + private static final long DATA_FILE_SIZE = 128 * 1024 * 1024; // 128 MB + private static final long DELETE_FILE_SIZE = 32 * 1024 * 1024; // 32 MB + + private final SparkSession spark; + private final Table table; + private final String branch; + private final RuntimeConfig sessionConf; + private final Map writeOptions; + private final SparkConfParser confParser; + + public SparkWriteConf(SparkSession spark, Table table, Map writeOptions) { + this(spark, table, null, writeOptions); + } + + public SparkWriteConf( + SparkSession spark, Table table, String branch, Map writeOptions) { + this.spark = spark; + this.table = table; + this.branch = branch; + this.sessionConf = spark.conf(); + this.writeOptions = writeOptions; + this.confParser = new SparkConfParser(spark, table, writeOptions); + } + + public boolean checkNullability() { + return confParser + .booleanConf() + .option(SparkWriteOptions.CHECK_NULLABILITY) + .sessionConf(SparkSQLProperties.CHECK_NULLABILITY) + .defaultValue(SparkSQLProperties.CHECK_NULLABILITY_DEFAULT) + .parse(); + } + + public boolean checkOrdering() { + return confParser + .booleanConf() + .option(SparkWriteOptions.CHECK_ORDERING) + .sessionConf(SparkSQLProperties.CHECK_ORDERING) + .defaultValue(SparkSQLProperties.CHECK_ORDERING_DEFAULT) + .parse(); + } + + public String overwriteMode() { + String overwriteMode = writeOptions.get(SparkWriteOptions.OVERWRITE_MODE); + return overwriteMode != null ? overwriteMode.toLowerCase(Locale.ROOT) : null; + } + + public boolean wapEnabled() { + return confParser + .booleanConf() + .tableProperty(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED) + .defaultValue(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED_DEFAULT) + .parse(); + } + + public String wapId() { + return sessionConf.get(SparkSQLProperties.WAP_ID, null); + } + + public boolean mergeSchema() { + return confParser + .booleanConf() + .option(SparkWriteOptions.MERGE_SCHEMA) + .option(SparkWriteOptions.SPARK_MERGE_SCHEMA) + .sessionConf(SparkSQLProperties.MERGE_SCHEMA) + .defaultValue(SparkSQLProperties.MERGE_SCHEMA_DEFAULT) + .parse(); + } + + public int outputSpecId() { + int outputSpecId = + confParser + .intConf() + .option(SparkWriteOptions.OUTPUT_SPEC_ID) + .defaultValue(table.spec().specId()) + .parse(); + Preconditions.checkArgument( + table.specs().containsKey(outputSpecId), + "Output spec id %s is not a valid spec id for table", + outputSpecId); + return outputSpecId; + } + + public FileFormat dataFileFormat() { + String valueAsString = + confParser + .stringConf() + .option(SparkWriteOptions.WRITE_FORMAT) + .tableProperty(TableProperties.DEFAULT_FILE_FORMAT) + .defaultValue(TableProperties.DEFAULT_FILE_FORMAT_DEFAULT) + .parse(); + return FileFormat.fromString(valueAsString); + } + + private String dataCompressionCodec() { + switch (dataFileFormat()) { + case PARQUET: + return parquetCompressionCodec(); + case AVRO: + return avroCompressionCodec(); + case ORC: + return orcCompressionCodec(); + default: + return null; + } + } + + public long targetDataFileSize() { + return confParser + .longConf() + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES) + .tableProperty(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES) + .defaultValue(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT) + .parse(); + } + + public boolean useFanoutWriter(SparkWriteRequirements writeRequirements) { + 
boolean defaultValue = !writeRequirements.hasOrdering(); + return fanoutWriterEnabled(defaultValue); + } + + private boolean fanoutWriterEnabled() { + return fanoutWriterEnabled(true /* enabled by default */); + } + + private boolean fanoutWriterEnabled(boolean defaultValue) { + return confParser + .booleanConf() + .option(SparkWriteOptions.FANOUT_ENABLED) + .tableProperty(TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED) + .defaultValue(defaultValue) + .parse(); + } + + public FileFormat deleteFileFormat() { + String valueAsString = + confParser + .stringConf() + .option(SparkWriteOptions.DELETE_FORMAT) + .tableProperty(TableProperties.DELETE_DEFAULT_FILE_FORMAT) + .parseOptional(); + return valueAsString != null ? FileFormat.fromString(valueAsString) : dataFileFormat(); + } + + private String deleteCompressionCodec() { + switch (deleteFileFormat()) { + case PARQUET: + return deleteParquetCompressionCodec(); + case AVRO: + return deleteAvroCompressionCodec(); + case ORC: + return deleteOrcCompressionCodec(); + default: + return null; + } + } + + public long targetDeleteFileSize() { + return confParser + .longConf() + .option(SparkWriteOptions.TARGET_DELETE_FILE_SIZE_BYTES) + .tableProperty(TableProperties.DELETE_TARGET_FILE_SIZE_BYTES) + .defaultValue(TableProperties.DELETE_TARGET_FILE_SIZE_BYTES_DEFAULT) + .parse(); + } + + public Map extraSnapshotMetadata() { + Map extraSnapshotMetadata = Maps.newHashMap(); + + writeOptions.forEach( + (key, value) -> { + if (key.startsWith(SnapshotSummary.EXTRA_METADATA_PREFIX)) { + extraSnapshotMetadata.put( + key.substring(SnapshotSummary.EXTRA_METADATA_PREFIX.length()), value); + } + }); + + return extraSnapshotMetadata; + } + + public String rewrittenFileSetId() { + return confParser + .stringConf() + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID) + .parseOptional(); + } + + public SparkWriteRequirements writeRequirements() { + if (ignoreTableDistributionAndOrdering()) { + LOG.info("Skipping distribution/ordering: disabled per job configuration"); + return SparkWriteRequirements.EMPTY; + } + + return SparkWriteUtil.writeRequirements( + table, distributionMode(), fanoutWriterEnabled(), dataAdvisoryPartitionSize()); + } + + @VisibleForTesting + DistributionMode distributionMode() { + String modeName = + confParser + .stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .sessionConf(SparkSQLProperties.DISTRIBUTION_MODE) + .tableProperty(TableProperties.WRITE_DISTRIBUTION_MODE) + .parseOptional(); + + if (modeName != null) { + DistributionMode mode = DistributionMode.fromName(modeName); + return adjustWriteDistributionMode(mode); + } else { + return defaultWriteDistributionMode(); + } + } + + private DistributionMode adjustWriteDistributionMode(DistributionMode mode) { + if (mode == RANGE && table.spec().isUnpartitioned() && table.sortOrder().isUnsorted()) { + return NONE; + } else if (mode == HASH && table.spec().isUnpartitioned()) { + return NONE; + } else { + return mode; + } + } + + private DistributionMode defaultWriteDistributionMode() { + if (table.sortOrder().isSorted()) { + return RANGE; + } else if (table.spec().isPartitioned()) { + return HASH; + } else { + return NONE; + } + } + + public SparkWriteRequirements copyOnWriteRequirements(Command command) { + if (ignoreTableDistributionAndOrdering()) { + LOG.info("Skipping distribution/ordering: disabled per job configuration"); + return SparkWriteRequirements.EMPTY; + } + + return SparkWriteUtil.copyOnWriteRequirements( + table, + command, + 
copyOnWriteDistributionMode(command), + fanoutWriterEnabled(), + dataAdvisoryPartitionSize()); + } + + @VisibleForTesting + DistributionMode copyOnWriteDistributionMode(Command command) { + switch (command) { + case DELETE: + return deleteDistributionMode(); + case UPDATE: + return updateDistributionMode(); + case MERGE: + return copyOnWriteMergeDistributionMode(); + default: + throw new IllegalArgumentException("Unexpected command: " + command); + } + } + + public SparkWriteRequirements positionDeltaRequirements(Command command) { + if (ignoreTableDistributionAndOrdering()) { + LOG.info("Skipping distribution/ordering: disabled per job configuration"); + return SparkWriteRequirements.EMPTY; + } + + return SparkWriteUtil.positionDeltaRequirements( + table, + command, + positionDeltaDistributionMode(command), + fanoutWriterEnabled(), + command == DELETE ? deleteAdvisoryPartitionSize() : dataAdvisoryPartitionSize()); + } + + @VisibleForTesting + DistributionMode positionDeltaDistributionMode(Command command) { + switch (command) { + case DELETE: + return deleteDistributionMode(); + case UPDATE: + return updateDistributionMode(); + case MERGE: + return positionDeltaMergeDistributionMode(); + default: + throw new IllegalArgumentException("Unexpected command: " + command); + } + } + + private DistributionMode deleteDistributionMode() { + String deleteModeName = + confParser + .stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .sessionConf(SparkSQLProperties.DISTRIBUTION_MODE) + .tableProperty(TableProperties.DELETE_DISTRIBUTION_MODE) + .defaultValue(TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .parse(); + return DistributionMode.fromName(deleteModeName); + } + + private DistributionMode updateDistributionMode() { + String updateModeName = + confParser + .stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .sessionConf(SparkSQLProperties.DISTRIBUTION_MODE) + .tableProperty(TableProperties.UPDATE_DISTRIBUTION_MODE) + .defaultValue(TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .parse(); + return DistributionMode.fromName(updateModeName); + } + + private DistributionMode copyOnWriteMergeDistributionMode() { + String mergeModeName = + confParser + .stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .sessionConf(SparkSQLProperties.DISTRIBUTION_MODE) + .tableProperty(TableProperties.MERGE_DISTRIBUTION_MODE) + .parseOptional(); + + if (mergeModeName != null) { + DistributionMode mergeMode = DistributionMode.fromName(mergeModeName); + return adjustWriteDistributionMode(mergeMode); + + } else if (table.spec().isPartitioned()) { + return HASH; + + } else { + return distributionMode(); + } + } + + private DistributionMode positionDeltaMergeDistributionMode() { + String mergeModeName = + confParser + .stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .sessionConf(SparkSQLProperties.DISTRIBUTION_MODE) + .tableProperty(TableProperties.MERGE_DISTRIBUTION_MODE) + .defaultValue(TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .parse(); + return DistributionMode.fromName(mergeModeName); + } + + private boolean ignoreTableDistributionAndOrdering() { + return confParser + .booleanConf() + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING) + .defaultValue(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING_DEFAULT) + .negate() + .parse(); + } + + public Long validateFromSnapshotId() { + return confParser + .longConf() + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID) + .parseOptional(); + } + + public IsolationLevel isolationLevel() { + String 
isolationLevelName = + confParser.stringConf().option(SparkWriteOptions.ISOLATION_LEVEL).parseOptional(); + return isolationLevelName != null ? IsolationLevel.fromName(isolationLevelName) : null; + } + + public boolean caseSensitive() { + return confParser + .booleanConf() + .sessionConf(SQLConf.CASE_SENSITIVE().key()) + .defaultValue(SQLConf.CASE_SENSITIVE().defaultValueString()) + .parse(); + } + + public String branch() { + if (wapEnabled()) { + String wapId = wapId(); + String wapBranch = + confParser.stringConf().sessionConf(SparkSQLProperties.WAP_BRANCH).parseOptional(); + + ValidationException.check( + wapId == null || wapBranch == null, + "Cannot set both WAP ID and branch, but got ID [%s] and branch [%s]", + wapId, + wapBranch); + + if (wapBranch != null) { + ValidationException.check( + branch == null, + "Cannot write to both branch and WAP branch, but got branch [%s] and WAP branch [%s]", + branch, + wapBranch); + + return wapBranch; + } + } + + return branch; + } + + public Map writeProperties() { + Map writeProperties = Maps.newHashMap(); + writeProperties.putAll(dataWriteProperties()); + writeProperties.putAll(deleteWriteProperties()); + return writeProperties; + } + + private Map dataWriteProperties() { + Map writeProperties = Maps.newHashMap(); + FileFormat dataFormat = dataFileFormat(); + + switch (dataFormat) { + case PARQUET: + writeProperties.put(PARQUET_COMPRESSION, parquetCompressionCodec()); + String parquetCompressionLevel = parquetCompressionLevel(); + if (parquetCompressionLevel != null) { + writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); + } + break; + + case AVRO: + writeProperties.put(AVRO_COMPRESSION, avroCompressionCodec()); + String avroCompressionLevel = avroCompressionLevel(); + if (avroCompressionLevel != null) { + writeProperties.put(AVRO_COMPRESSION_LEVEL, avroCompressionLevel); + } + break; + + case ORC: + writeProperties.put(ORC_COMPRESSION, orcCompressionCodec()); + writeProperties.put(ORC_COMPRESSION_STRATEGY, orcCompressionStrategy()); + break; + + default: + // skip + } + + return writeProperties; + } + + private Map deleteWriteProperties() { + Map writeProperties = Maps.newHashMap(); + FileFormat deleteFormat = deleteFileFormat(); + + switch (deleteFormat) { + case PARQUET: + writeProperties.put(DELETE_PARQUET_COMPRESSION, deleteParquetCompressionCodec()); + String deleteParquetCompressionLevel = deleteParquetCompressionLevel(); + if (deleteParquetCompressionLevel != null) { + writeProperties.put(DELETE_PARQUET_COMPRESSION_LEVEL, deleteParquetCompressionLevel); + } + break; + + case AVRO: + writeProperties.put(DELETE_AVRO_COMPRESSION, deleteAvroCompressionCodec()); + String deleteAvroCompressionLevel = deleteAvroCompressionLevel(); + if (deleteAvroCompressionLevel != null) { + writeProperties.put(DELETE_AVRO_COMPRESSION_LEVEL, deleteAvroCompressionLevel); + } + break; + + case ORC: + writeProperties.put(DELETE_ORC_COMPRESSION, deleteOrcCompressionCodec()); + writeProperties.put(DELETE_ORC_COMPRESSION_STRATEGY, deleteOrcCompressionStrategy()); + break; + + default: + // skip + } + + return writeProperties; + } + + private String parquetCompressionCodec() { + return confParser + .stringConf() + .option(SparkWriteOptions.COMPRESSION_CODEC) + .sessionConf(SparkSQLProperties.COMPRESSION_CODEC) + .tableProperty(TableProperties.PARQUET_COMPRESSION) + .defaultValue(TableProperties.PARQUET_COMPRESSION_DEFAULT) + .parse(); + } + + private String deleteParquetCompressionCodec() { + return confParser + .stringConf() + 
.option(SparkWriteOptions.COMPRESSION_CODEC) + .sessionConf(SparkSQLProperties.COMPRESSION_CODEC) + .tableProperty(DELETE_PARQUET_COMPRESSION) + .defaultValue(parquetCompressionCodec()) + .parse(); + } + + private String parquetCompressionLevel() { + return confParser + .stringConf() + .option(SparkWriteOptions.COMPRESSION_LEVEL) + .sessionConf(SparkSQLProperties.COMPRESSION_LEVEL) + .tableProperty(TableProperties.PARQUET_COMPRESSION_LEVEL) + .defaultValue(TableProperties.PARQUET_COMPRESSION_LEVEL_DEFAULT) + .parseOptional(); + } + + private String deleteParquetCompressionLevel() { + return confParser + .stringConf() + .option(SparkWriteOptions.COMPRESSION_LEVEL) + .sessionConf(SparkSQLProperties.COMPRESSION_LEVEL) + .tableProperty(DELETE_PARQUET_COMPRESSION_LEVEL) + .defaultValue(parquetCompressionLevel()) + .parseOptional(); + } + + private String avroCompressionCodec() { + return confParser + .stringConf() + .option(SparkWriteOptions.COMPRESSION_CODEC) + .sessionConf(SparkSQLProperties.COMPRESSION_CODEC) + .tableProperty(TableProperties.AVRO_COMPRESSION) + .defaultValue(TableProperties.AVRO_COMPRESSION_DEFAULT) + .parse(); + } + + private String deleteAvroCompressionCodec() { + return confParser + .stringConf() + .option(SparkWriteOptions.COMPRESSION_CODEC) + .sessionConf(SparkSQLProperties.COMPRESSION_CODEC) + .tableProperty(DELETE_AVRO_COMPRESSION) + .defaultValue(avroCompressionCodec()) + .parse(); + } + + private String avroCompressionLevel() { + return confParser + .stringConf() + .option(SparkWriteOptions.COMPRESSION_LEVEL) + .sessionConf(SparkSQLProperties.COMPRESSION_LEVEL) + .tableProperty(TableProperties.AVRO_COMPRESSION_LEVEL) + .defaultValue(TableProperties.AVRO_COMPRESSION_LEVEL_DEFAULT) + .parseOptional(); + } + + private String deleteAvroCompressionLevel() { + return confParser + .stringConf() + .option(SparkWriteOptions.COMPRESSION_LEVEL) + .sessionConf(SparkSQLProperties.COMPRESSION_LEVEL) + .tableProperty(DELETE_AVRO_COMPRESSION_LEVEL) + .defaultValue(avroCompressionLevel()) + .parseOptional(); + } + + private String orcCompressionCodec() { + return confParser + .stringConf() + .option(SparkWriteOptions.COMPRESSION_CODEC) + .sessionConf(SparkSQLProperties.COMPRESSION_CODEC) + .tableProperty(TableProperties.ORC_COMPRESSION) + .defaultValue(TableProperties.ORC_COMPRESSION_DEFAULT) + .parse(); + } + + private String deleteOrcCompressionCodec() { + return confParser + .stringConf() + .option(SparkWriteOptions.COMPRESSION_CODEC) + .sessionConf(SparkSQLProperties.COMPRESSION_CODEC) + .tableProperty(DELETE_ORC_COMPRESSION) + .defaultValue(orcCompressionCodec()) + .parse(); + } + + private String orcCompressionStrategy() { + return confParser + .stringConf() + .option(SparkWriteOptions.COMPRESSION_STRATEGY) + .sessionConf(SparkSQLProperties.COMPRESSION_STRATEGY) + .tableProperty(TableProperties.ORC_COMPRESSION_STRATEGY) + .defaultValue(TableProperties.ORC_COMPRESSION_STRATEGY_DEFAULT) + .parse(); + } + + private String deleteOrcCompressionStrategy() { + return confParser + .stringConf() + .option(SparkWriteOptions.COMPRESSION_STRATEGY) + .sessionConf(SparkSQLProperties.COMPRESSION_STRATEGY) + .tableProperty(DELETE_ORC_COMPRESSION_STRATEGY) + .defaultValue(orcCompressionStrategy()) + .parse(); + } + + private long dataAdvisoryPartitionSize() { + long defaultValue = + advisoryPartitionSize(DATA_FILE_SIZE, dataFileFormat(), dataCompressionCodec()); + return advisoryPartitionSize(defaultValue); + } + + private long deleteAdvisoryPartitionSize() { + long defaultValue = + 
advisoryPartitionSize(DELETE_FILE_SIZE, deleteFileFormat(), deleteCompressionCodec()); + return advisoryPartitionSize(defaultValue); + } + + private long advisoryPartitionSize(long defaultValue) { + return confParser + .longConf() + .option(SparkWriteOptions.ADVISORY_PARTITION_SIZE) + .sessionConf(SparkSQLProperties.ADVISORY_PARTITION_SIZE) + .tableProperty(TableProperties.SPARK_WRITE_ADVISORY_PARTITION_SIZE_BYTES) + .defaultValue(defaultValue) + .parse(); + } + + private long advisoryPartitionSize( + long expectedFileSize, FileFormat outputFileFormat, String outputCodec) { + double shuffleCompressionRatio = shuffleCompressionRatio(outputFileFormat, outputCodec); + long suggestedAdvisoryPartitionSize = (long) (expectedFileSize * shuffleCompressionRatio); + return Math.max(suggestedAdvisoryPartitionSize, sparkAdvisoryPartitionSize()); + } + + private long sparkAdvisoryPartitionSize() { + return (long) spark.sessionState().conf().getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES()); + } + + private double shuffleCompressionRatio(FileFormat outputFileFormat, String outputCodec) { + return SparkCompressionUtil.shuffleCompressionRatio(spark, outputFileFormat, outputCodec); + } + + public DeleteGranularity deleteGranularity() { + String valueAsString = + confParser + .stringConf() + .option(SparkWriteOptions.DELETE_GRANULARITY) + .tableProperty(TableProperties.DELETE_GRANULARITY) + .defaultValue(TableProperties.DELETE_GRANULARITY_DEFAULT) + .parse(); + return DeleteGranularity.fromString(valueAsString); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java new file mode 100644 index 000000000000..33db70bae587 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +/** Spark DF write options */ +public class SparkWriteOptions { + + private SparkWriteOptions() {} + + // Fileformat for write operations(default: Table write.format.default ) + public static final String WRITE_FORMAT = "write-format"; + + // Overrides this table's write.target-file-size-bytes + public static final String TARGET_FILE_SIZE_BYTES = "target-file-size-bytes"; + + // Overrides the default file format for delete files + public static final String DELETE_FORMAT = "delete-format"; + + // Overrides the default size for delete files + public static final String TARGET_DELETE_FILE_SIZE_BYTES = "target-delete-file-size-bytes"; + + // Sets the nullable check on fields(default: true) + public static final String CHECK_NULLABILITY = "check-nullability"; + + // Adds an entry with custom-key and corresponding value in the snapshot summary + // ex: df.write().format(iceberg) + // .option(SparkWriteOptions.SNAPSHOT_PROPERTY_PREFIX."key1", "value1") + // .save(location) + public static final String SNAPSHOT_PROPERTY_PREFIX = "snapshot-property"; + + // Overrides table property write.spark.fanout.enabled(default: false) + public static final String FANOUT_ENABLED = "fanout-enabled"; + + // Checks if input schema and table schema are same(default: true) + public static final String CHECK_ORDERING = "check-ordering"; + + // File scan task set ID that indicates which files must be replaced + public static final String REWRITTEN_FILE_SCAN_TASK_SET_ID = "rewritten-file-scan-task-set-id"; + + public static final String OUTPUT_SPEC_ID = "output-spec-id"; + + public static final String OVERWRITE_MODE = "overwrite-mode"; + + // Overrides the default distribution mode for a write operation + public static final String DISTRIBUTION_MODE = "distribution-mode"; + + // Controls whether to take into account the table distribution and sort order during a write + // operation + public static final String USE_TABLE_DISTRIBUTION_AND_ORDERING = + "use-table-distribution-and-ordering"; + public static final boolean USE_TABLE_DISTRIBUTION_AND_ORDERING_DEFAULT = true; + + public static final String MERGE_SCHEMA = "merge-schema"; + public static final String SPARK_MERGE_SCHEMA = "mergeSchema"; + + // Identifies snapshot from which to start validating conflicting changes + public static final String VALIDATE_FROM_SNAPSHOT_ID = "validate-from-snapshot-id"; + + // Isolation Level for DataFrame calls. Currently supported by overwritePartitions + public static final String ISOLATION_LEVEL = "isolation-level"; + + // Controls write compress options + public static final String COMPRESSION_CODEC = "compression-codec"; + public static final String COMPRESSION_LEVEL = "compression-level"; + public static final String COMPRESSION_STRATEGY = "compression-strategy"; + + // Overrides the advisory partition size + public static final String ADVISORY_PARTITION_SIZE = "advisory-partition-size"; + + // Overrides the delete granularity + public static final String DELETE_GRANULARITY = "delete-granularity"; +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteRequirements.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteRequirements.java new file mode 100644 index 000000000000..833e0e44e391 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteRequirements.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution; +import org.apache.spark.sql.connector.expressions.SortOrder; + +/** A set of requirements such as distribution and ordering reported to Spark during writes. */ +public class SparkWriteRequirements { + + public static final SparkWriteRequirements EMPTY = + new SparkWriteRequirements(Distributions.unspecified(), new SortOrder[0], 0); + + private final Distribution distribution; + private final SortOrder[] ordering; + private final long advisoryPartitionSize; + + SparkWriteRequirements( + Distribution distribution, SortOrder[] ordering, long advisoryPartitionSize) { + this.distribution = distribution; + this.ordering = ordering; + this.advisoryPartitionSize = advisoryPartitionSize; + } + + public Distribution distribution() { + return distribution; + } + + public SortOrder[] ordering() { + return ordering; + } + + public boolean hasOrdering() { + return ordering.length != 0; + } + + public long advisoryPartitionSize() { + // Spark prohibits requesting a particular advisory partition size without distribution + return distribution instanceof UnspecifiedDistribution ? 0 : advisoryPartitionSize; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteUtil.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteUtil.java new file mode 100644 index 000000000000..0d68a0d8cdd0 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteUtil.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; + +import java.util.Arrays; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ObjectArrays; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.SortOrderUtil; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.SortDirection; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; + +/** + * A utility that contains helper methods for working with Spark writes. + * + *
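+ * <p>A minimal sketch of how these helpers are typically combined, assuming an append to a
+ * hash-distributed table (illustrative only, not part of this change):
+ *
+ * <pre>{@code
+ * SparkWriteRequirements req =
+ *     SparkWriteUtil.writeRequirements(table, DistributionMode.HASH, true, 64L * 1024 * 1024);
+ * Distribution distribution = req.distribution(); // clustered by the table's partition transforms
+ * SortOrder[] ordering = req.ordering();          // empty if fanout is enabled and the table is unsorted
+ * }</pre>
+ *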
<p>
Note it is an evolving internal API that is subject to change even in minor releases. + */ +public class SparkWriteUtil { + + private static final NamedReference SPEC_ID = ref(MetadataColumns.SPEC_ID); + private static final NamedReference PARTITION = ref(MetadataColumns.PARTITION_COLUMN_NAME); + private static final NamedReference FILE_PATH = ref(MetadataColumns.FILE_PATH); + private static final NamedReference ROW_POSITION = ref(MetadataColumns.ROW_POSITION); + + private static final Expression[] FILE_CLUSTERING = clusterBy(FILE_PATH); + private static final Expression[] PARTITION_CLUSTERING = clusterBy(SPEC_ID, PARTITION); + private static final Expression[] PARTITION_FILE_CLUSTERING = + clusterBy(SPEC_ID, PARTITION, FILE_PATH); + + private static final SortOrder[] EMPTY_ORDERING = new SortOrder[0]; + private static final SortOrder[] EXISTING_ROW_ORDERING = orderBy(FILE_PATH, ROW_POSITION); + private static final SortOrder[] PARTITION_ORDERING = orderBy(SPEC_ID, PARTITION); + private static final SortOrder[] PARTITION_FILE_ORDERING = orderBy(SPEC_ID, PARTITION, FILE_PATH); + private static final SortOrder[] POSITION_DELETE_ORDERING = + orderBy(SPEC_ID, PARTITION, FILE_PATH, ROW_POSITION); + + private SparkWriteUtil() {} + + /** Builds requirements for batch and micro-batch writes such as append or overwrite. */ + public static SparkWriteRequirements writeRequirements( + Table table, DistributionMode mode, boolean fanoutEnabled, long advisoryPartitionSize) { + + Distribution distribution = writeDistribution(table, mode); + SortOrder[] ordering = writeOrdering(table, fanoutEnabled); + return new SparkWriteRequirements(distribution, ordering, advisoryPartitionSize); + } + + private static Distribution writeDistribution(Table table, DistributionMode mode) { + switch (mode) { + case NONE: + return Distributions.unspecified(); + + case HASH: + return Distributions.clustered(clustering(table)); + + case RANGE: + return Distributions.ordered(ordering(table)); + + default: + throw new IllegalArgumentException("Unsupported distribution mode: " + mode); + } + } + + /** Builds requirements for copy-on-write DELETE, UPDATE, MERGE operations. */ + public static SparkWriteRequirements copyOnWriteRequirements( + Table table, + Command command, + DistributionMode mode, + boolean fanoutEnabled, + long advisoryPartitionSize) { + + if (command == DELETE || command == UPDATE) { + Distribution distribution = copyOnWriteDeleteUpdateDistribution(table, mode); + SortOrder[] ordering = writeOrdering(table, fanoutEnabled); + return new SparkWriteRequirements(distribution, ordering, advisoryPartitionSize); + } else { + return writeRequirements(table, mode, fanoutEnabled, advisoryPartitionSize); + } + } + + private static Distribution copyOnWriteDeleteUpdateDistribution( + Table table, DistributionMode mode) { + + switch (mode) { + case NONE: + return Distributions.unspecified(); + + case HASH: + if (table.spec().isPartitioned()) { + return Distributions.clustered(clustering(table)); + } else { + return Distributions.clustered(FILE_CLUSTERING); + } + + case RANGE: + if (table.spec().isPartitioned() || table.sortOrder().isSorted()) { + return Distributions.ordered(ordering(table)); + } else { + return Distributions.ordered(EXISTING_ROW_ORDERING); + } + + default: + throw new IllegalArgumentException("Unexpected distribution mode: " + mode); + } + } + + /** Builds requirements for merge-on-read DELETE, UPDATE, MERGE operations. 
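+ * <p>For DELETE the requested clustering covers only metadata columns (spec ID, partition and,
+ * for unpartitioned tables, file path); UPDATE and MERGE additionally take the table's own
+ * partition transforms and sort order into account. A hedged example call, with the table and
+ * distribution mode assumed for illustration:
+ *
+ * <pre>{@code
+ * SparkWriteRequirements req =
+ *     SparkWriteUtil.positionDeltaRequirements(table, Command.MERGE, DistributionMode.HASH, false, 0L);
+ * }</pre>
+ *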
*/ + public static SparkWriteRequirements positionDeltaRequirements( + Table table, + Command command, + DistributionMode mode, + boolean fanoutEnabled, + long advisoryPartitionSize) { + + if (command == UPDATE || command == MERGE) { + Distribution distribution = positionDeltaUpdateMergeDistribution(table, mode); + SortOrder[] ordering = positionDeltaUpdateMergeOrdering(table, fanoutEnabled); + return new SparkWriteRequirements(distribution, ordering, advisoryPartitionSize); + } else { + Distribution distribution = positionDeltaDeleteDistribution(table, mode); + SortOrder[] ordering = fanoutEnabled ? EMPTY_ORDERING : POSITION_DELETE_ORDERING; + return new SparkWriteRequirements(distribution, ordering, advisoryPartitionSize); + } + } + + private static Distribution positionDeltaUpdateMergeDistribution( + Table table, DistributionMode mode) { + + switch (mode) { + case NONE: + return Distributions.unspecified(); + + case HASH: + if (table.spec().isUnpartitioned()) { + return Distributions.clustered(concat(PARTITION_FILE_CLUSTERING, clustering(table))); + } else { + return Distributions.clustered(concat(PARTITION_CLUSTERING, clustering(table))); + } + + case RANGE: + if (table.spec().isUnpartitioned()) { + return Distributions.ordered(concat(PARTITION_FILE_ORDERING, ordering(table))); + } else { + return Distributions.ordered(concat(PARTITION_ORDERING, ordering(table))); + } + + default: + throw new IllegalArgumentException("Unsupported distribution mode: " + mode); + } + } + + private static SortOrder[] positionDeltaUpdateMergeOrdering(Table table, boolean fanoutEnabled) { + if (fanoutEnabled && table.sortOrder().isUnsorted()) { + return EMPTY_ORDERING; + } else { + return concat(POSITION_DELETE_ORDERING, ordering(table)); + } + } + + private static Distribution positionDeltaDeleteDistribution(Table table, DistributionMode mode) { + switch (mode) { + case NONE: + return Distributions.unspecified(); + + case HASH: + if (table.spec().isUnpartitioned()) { + return Distributions.clustered(PARTITION_FILE_CLUSTERING); + } else { + return Distributions.clustered(PARTITION_CLUSTERING); + } + + case RANGE: + if (table.spec().isUnpartitioned()) { + return Distributions.ordered(PARTITION_FILE_ORDERING); + } else { + return Distributions.ordered(PARTITION_ORDERING); + } + + default: + throw new IllegalArgumentException("Unsupported distribution mode: " + mode); + } + } + + // a local ordering within a task is beneficial in two cases: + // - there is a defined table sort order, so it is clear how the data should be ordered + // - the table is partitioned and fanout writers are disabled, + // so records for one partition must be co-located within a task + private static SortOrder[] writeOrdering(Table table, boolean fanoutEnabled) { + if (fanoutEnabled && table.sortOrder().isUnsorted()) { + return EMPTY_ORDERING; + } else { + return ordering(table); + } + } + + private static Expression[] clustering(Table table) { + return Spark3Util.toTransforms(table.spec()); + } + + private static SortOrder[] ordering(Table table) { + return Spark3Util.toOrdering(SortOrderUtil.buildSortOrder(table)); + } + + private static Expression[] concat(Expression[] clustering, Expression... otherClustering) { + return ObjectArrays.concat(clustering, otherClustering, Expression.class); + } + + private static SortOrder[] concat(SortOrder[] ordering, SortOrder... 
otherOrdering) { + return ObjectArrays.concat(ordering, otherOrdering, SortOrder.class); + } + + private static NamedReference ref(Types.NestedField field) { + return Expressions.column(field.name()); + } + + private static NamedReference ref(String name) { + return Expressions.column(name); + } + + private static Expression[] clusterBy(Expression... exprs) { + return exprs; + } + + private static SortOrder[] orderBy(Expression... exprs) { + return Arrays.stream(exprs).map(SparkWriteUtil::sort).toArray(SortOrder[]::new); + } + + private static SortOrder sort(Expression expr) { + return Expressions.sort(expr, SortDirection.ASCENDING); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SupportsFunctions.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SupportsFunctions.java new file mode 100644 index 000000000000..34897d2b4c01 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SupportsFunctions.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.spark.functions.SparkFunctions; +import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.connector.catalog.FunctionCatalog; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; + +interface SupportsFunctions extends FunctionCatalog { + + default boolean isFunctionNamespace(String[] namespace) { + return namespace.length == 0; + } + + default boolean isExistingNamespace(String[] namespace) { + return namespace.length == 0; + } + + default Identifier[] listFunctions(String[] namespace) throws NoSuchNamespaceException { + if (isFunctionNamespace(namespace)) { + return SparkFunctions.list().stream() + .map(name -> Identifier.of(namespace, name)) + .toArray(Identifier[]::new); + } else if (isExistingNamespace(namespace)) { + return new Identifier[0]; + } + + throw new NoSuchNamespaceException(namespace); + } + + default UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException { + String[] namespace = ident.namespace(); + String name = ident.name(); + + if (isFunctionNamespace(namespace)) { + UnboundFunction func = SparkFunctions.load(name); + if (func != null) { + return func; + } + } + + throw new NoSuchFunctionException(ident); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SupportsReplaceView.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SupportsReplaceView.java new file mode 100644 index 000000000000..8bdb7b13861c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SupportsReplaceView.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NoSuchViewException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.View; +import org.apache.spark.sql.connector.catalog.ViewCatalog; +import org.apache.spark.sql.types.StructType; + +public interface SupportsReplaceView extends ViewCatalog { + /** + * Replace a view in the catalog + * + * @param ident a view identifier + * @param sql the SQL text that defines the view + * @param currentCatalog the current catalog + * @param currentNamespace the current namespace + * @param schema the view query output schema + * @param queryColumnNames the query column names + * @param columnAliases the column aliases + * @param columnComments the column comments + * @param properties the view properties + * @throws NoSuchViewException If the view doesn't exist or is a table + * @throws NoSuchNamespaceException If the identifier namespace does not exist (optional) + */ + View replaceView( + Identifier ident, + String sql, + String currentCatalog, + String[] currentNamespace, + StructType schema, + String[] queryColumnNames, + String[] columnAliases, + String[] columnComments, + Map properties) + throws NoSuchViewException, NoSuchNamespaceException; +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/TypeToSparkType.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/TypeToSparkType.java new file mode 100644 index 000000000000..dfb9b30be603 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/TypeToSparkType.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.ArrayType$; +import org.apache.spark.sql.types.BinaryType$; +import org.apache.spark.sql.types.BooleanType$; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType$; +import org.apache.spark.sql.types.DecimalType$; +import org.apache.spark.sql.types.DoubleType$; +import org.apache.spark.sql.types.FloatType$; +import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.MapType$; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.MetadataBuilder; +import org.apache.spark.sql.types.StringType$; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType$; +import org.apache.spark.sql.types.TimestampNTZType$; +import org.apache.spark.sql.types.TimestampType$; + +class TypeToSparkType extends TypeUtil.SchemaVisitor { + TypeToSparkType() {} + + public static final String METADATA_COL_ATTR_KEY = "__metadata_col"; + + @Override + public DataType schema(Schema schema, DataType structType) { + return structType; + } + + @Override + public DataType struct(Types.StructType struct, List fieldResults) { + List fields = struct.fields(); + + List sparkFields = Lists.newArrayListWithExpectedSize(fieldResults.size()); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + DataType type = fieldResults.get(i); + Metadata metadata = fieldMetadata(field.fieldId()); + StructField sparkField = StructField.apply(field.name(), type, field.isOptional(), metadata); + if (field.doc() != null) { + sparkField = sparkField.withComment(field.doc()); + } + sparkFields.add(sparkField); + } + + return StructType$.MODULE$.apply(sparkFields); + } + + @Override + public DataType field(Types.NestedField field, DataType fieldResult) { + return fieldResult; + } + + @Override + public DataType list(Types.ListType list, DataType elementResult) { + return ArrayType$.MODULE$.apply(elementResult, list.isElementOptional()); + } + + @Override + public DataType map(Types.MapType map, DataType keyResult, DataType valueResult) { + return MapType$.MODULE$.apply(keyResult, valueResult, map.isValueOptional()); + } + + @Override + public DataType primitive(Type.PrimitiveType primitive) { + switch (primitive.typeId()) { + case BOOLEAN: + return BooleanType$.MODULE$; + case INTEGER: + return IntegerType$.MODULE$; + case LONG: + return LongType$.MODULE$; + case FLOAT: + return FloatType$.MODULE$; + case DOUBLE: + return DoubleType$.MODULE$; + case DATE: + return DateType$.MODULE$; + case TIME: + throw new UnsupportedOperationException("Spark does not support time fields"); + case TIMESTAMP: + Types.TimestampType ts = (Types.TimestampType) primitive; + if (ts.shouldAdjustToUTC()) { + return TimestampType$.MODULE$; + } else { + return TimestampNTZType$.MODULE$; + } + case STRING: + return StringType$.MODULE$; + case UUID: + // use String + return StringType$.MODULE$; + case FIXED: + return BinaryType$.MODULE$; + case BINARY: + return BinaryType$.MODULE$; + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + return DecimalType$.MODULE$.apply(decimal.precision(), 
decimal.scale()); + default: + throw new UnsupportedOperationException( + "Cannot convert unknown type to Spark: " + primitive); + } + } + + private Metadata fieldMetadata(int fieldId) { + if (MetadataColumns.metadataFieldIds().contains(fieldId)) { + return new MetadataBuilder().putBoolean(METADATA_COL_ATTR_KEY, true).build(); + } + + return Metadata.empty(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSnapshotUpdateSparkAction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSnapshotUpdateSparkAction.java new file mode 100644 index 000000000000..b69b80a8d3a6 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSnapshotUpdateSparkAction.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.SparkSession; + +abstract class BaseSnapshotUpdateSparkAction extends BaseSparkAction { + + private final Map summary = Maps.newHashMap(); + + protected BaseSnapshotUpdateSparkAction(SparkSession spark) { + super(spark); + } + + public ThisT snapshotProperty(String property, String value) { + summary.put(property, value); + return self(); + } + + protected void commit(org.apache.iceberg.SnapshotUpdate update) { + summary.forEach(update::set); + update.commit(); + } + + protected Map commitSummary() { + return ImmutableMap.copyOf(summary); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSparkAction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSparkAction.java new file mode 100644 index 000000000000..34bb4afe67f9 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSparkAction.java @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.MetadataTableType.ALL_MANIFESTS; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.lit; + +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.apache.iceberg.AllManifestsTable; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileContent; +import org.apache.iceberg.ManifestContent; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.ReachableFileUtil; +import org.apache.iceberg.StaticTableOperations; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.exceptions.NotFoundException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.io.BulkDeletionFailureException; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.ClosingIterator; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterators; +import org.apache.iceberg.relocated.com.google.common.collect.ListMultimap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Multimaps; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.JobGroupUtils; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.source.SerializableTableWithSize; +import org.apache.iceberg.util.Tasks; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +abstract class BaseSparkAction { + + protected static final String MANIFEST = "Manifest"; + protected static final String MANIFEST_LIST = "Manifest List"; + protected static final String STATISTICS_FILES = "Statistics Files"; + protected static final String OTHERS = "Others"; + + protected static final String FILE_PATH = "file_path"; + protected static final String LAST_MODIFIED = "last_modified"; + + protected static final Splitter COMMA_SPLITTER = Splitter.on(","); + protected static final Joiner COMMA_JOINER = Joiner.on(','); + + private static final Logger LOG = LoggerFactory.getLogger(BaseSparkAction.class); + private static final AtomicInteger JOB_COUNTER = new AtomicInteger(); + private static final int DELETE_NUM_RETRIES = 3; + private static final 
int DELETE_GROUP_SIZE = 100000; + + private final SparkSession spark; + private final JavaSparkContext sparkContext; + private final Map options = Maps.newHashMap(); + + protected BaseSparkAction(SparkSession spark) { + this.spark = spark; + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + protected SparkSession spark() { + return spark; + } + + protected JavaSparkContext sparkContext() { + return sparkContext; + } + + protected abstract ThisT self(); + + public ThisT option(String name, String value) { + options.put(name, value); + return self(); + } + + public ThisT options(Map newOptions) { + options.putAll(newOptions); + return self(); + } + + protected Map options() { + return options; + } + + protected T withJobGroupInfo(JobGroupInfo info, Supplier supplier) { + return JobGroupUtils.withJobGroupInfo(sparkContext, info, supplier); + } + + protected JobGroupInfo newJobGroupInfo(String groupId, String desc) { + return new JobGroupInfo(groupId + "-" + JOB_COUNTER.incrementAndGet(), desc); + } + + protected Table newStaticTable(TableMetadata metadata, FileIO io) { + StaticTableOperations ops = new StaticTableOperations(metadata, io); + return new BaseTable(ops, metadata.metadataFileLocation()); + } + + protected Dataset contentFileDS(Table table) { + return contentFileDS(table, null); + } + + protected Dataset contentFileDS(Table table, Set snapshotIds) { + Table serializableTable = SerializableTableWithSize.copyOf(table); + Broadcast
<Table>
tableBroadcast = sparkContext.broadcast(serializableTable); + int numShufflePartitions = spark.sessionState().conf().numShufflePartitions(); + + Dataset manifestBeanDS = + manifestDF(table, snapshotIds) + .selectExpr( + "content", + "path", + "length", + "0 as sequenceNumber", + "partition_spec_id as partitionSpecId", + "added_snapshot_id as addedSnapshotId") + .dropDuplicates("path") + .repartition(numShufflePartitions) // avoid adaptive execution combining tasks + .as(ManifestFileBean.ENCODER); + + return manifestBeanDS.flatMap(new ReadManifest(tableBroadcast), FileInfo.ENCODER); + } + + protected Dataset manifestDS(Table table) { + return manifestDS(table, null); + } + + protected Dataset manifestDS(Table table, Set snapshotIds) { + return manifestDF(table, snapshotIds) + .select(col("path"), lit(MANIFEST).as("type")) + .as(FileInfo.ENCODER); + } + + private Dataset manifestDF(Table table, Set snapshotIds) { + Dataset manifestDF = loadMetadataTable(table, ALL_MANIFESTS); + if (snapshotIds != null) { + Column filterCond = col(AllManifestsTable.REF_SNAPSHOT_ID.name()).isInCollection(snapshotIds); + return manifestDF.filter(filterCond); + } else { + return manifestDF; + } + } + + protected Dataset manifestListDS(Table table) { + return manifestListDS(table, null); + } + + protected Dataset manifestListDS(Table table, Set snapshotIds) { + List manifestLists = ReachableFileUtil.manifestListLocations(table, snapshotIds); + return toFileInfoDS(manifestLists, MANIFEST_LIST); + } + + protected Dataset statisticsFileDS(Table table, Set snapshotIds) { + List statisticsFiles = + ReachableFileUtil.statisticsFilesLocationsForSnapshots(table, snapshotIds); + return toFileInfoDS(statisticsFiles, STATISTICS_FILES); + } + + protected Dataset otherMetadataFileDS(Table table) { + return otherMetadataFileDS(table, false /* include all reachable old metadata locations */); + } + + protected Dataset allReachableOtherMetadataFileDS(Table table) { + return otherMetadataFileDS(table, true /* include all reachable old metadata locations */); + } + + private Dataset otherMetadataFileDS(Table table, boolean recursive) { + List otherMetadataFiles = Lists.newArrayList(); + otherMetadataFiles.addAll(ReachableFileUtil.metadataFileLocations(table, recursive)); + otherMetadataFiles.add(ReachableFileUtil.versionHintLocation(table)); + otherMetadataFiles.addAll(ReachableFileUtil.statisticsFilesLocations(table)); + return toFileInfoDS(otherMetadataFiles, OTHERS); + } + + protected Dataset loadMetadataTable(Table table, MetadataTableType type) { + return SparkTableUtil.loadMetadataTable(spark, table, type); + } + + private Dataset toFileInfoDS(List paths, String type) { + List fileInfoList = Lists.transform(paths, path -> new FileInfo(path, type)); + return spark.createDataset(fileInfoList, FileInfo.ENCODER); + } + + /** + * Deletes files and keeps track of how many files were removed for each file type. 
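+ * <p>A minimal sketch of how an action might drive this overload (the executor service and the
+ * file iterator are assumed here for illustration, not defined in this diff):
+ *
+ * <pre>{@code
+ * ExecutorService pool = Executors.newFixedThreadPool(4);
+ * DeleteSummary summary = deleteFiles(pool, table.io()::deleteFile, expiredFiles);
+ * LOG.info("Removed {} files in total", summary.totalFilesCount());
+ * }</pre>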
+ * + * @param executorService an executor service to use for parallel deletes + * @param deleteFunc a delete func + * @param files an iterator of Spark rows of the structure (path: String, type: String) + * @return stats on which files were deleted + */ + protected DeleteSummary deleteFiles( + ExecutorService executorService, Consumer deleteFunc, Iterator files) { + + DeleteSummary summary = new DeleteSummary(); + + Tasks.foreach(files) + .retry(DELETE_NUM_RETRIES) + .stopRetryOn(NotFoundException.class) + .suppressFailureWhenFinished() + .executeWith(executorService) + .onFailure( + (fileInfo, exc) -> { + String path = fileInfo.getPath(); + String type = fileInfo.getType(); + LOG.warn("Delete failed for {}: {}", type, path, exc); + }) + .run( + fileInfo -> { + String path = fileInfo.getPath(); + String type = fileInfo.getType(); + deleteFunc.accept(path); + summary.deletedFile(path, type); + }); + + return summary; + } + + protected DeleteSummary deleteFiles(SupportsBulkOperations io, Iterator files) { + DeleteSummary summary = new DeleteSummary(); + Iterator> fileGroups = Iterators.partition(files, DELETE_GROUP_SIZE); + + Tasks.foreach(fileGroups) + .suppressFailureWhenFinished() + .run(fileGroup -> deleteFileGroup(fileGroup, io, summary)); + + return summary; + } + + private static void deleteFileGroup( + List fileGroup, SupportsBulkOperations io, DeleteSummary summary) { + + ListMultimap filesByType = Multimaps.index(fileGroup, FileInfo::getType); + ListMultimap pathsByType = + Multimaps.transformValues(filesByType, FileInfo::getPath); + + for (Map.Entry> entry : pathsByType.asMap().entrySet()) { + String type = entry.getKey(); + Collection paths = entry.getValue(); + int failures = 0; + try { + io.deleteFiles(paths); + } catch (BulkDeletionFailureException e) { + failures = e.numberFailedObjects(); + } + summary.deletedFiles(type, paths.size() - failures); + } + } + + static class DeleteSummary { + private final AtomicLong dataFilesCount = new AtomicLong(0L); + private final AtomicLong positionDeleteFilesCount = new AtomicLong(0L); + private final AtomicLong equalityDeleteFilesCount = new AtomicLong(0L); + private final AtomicLong manifestsCount = new AtomicLong(0L); + private final AtomicLong manifestListsCount = new AtomicLong(0L); + private final AtomicLong statisticsFilesCount = new AtomicLong(0L); + private final AtomicLong otherFilesCount = new AtomicLong(0L); + + public void deletedFiles(String type, int numFiles) { + if (FileContent.DATA.name().equalsIgnoreCase(type)) { + dataFilesCount.addAndGet(numFiles); + + } else if (FileContent.POSITION_DELETES.name().equalsIgnoreCase(type)) { + positionDeleteFilesCount.addAndGet(numFiles); + + } else if (FileContent.EQUALITY_DELETES.name().equalsIgnoreCase(type)) { + equalityDeleteFilesCount.addAndGet(numFiles); + + } else if (MANIFEST.equalsIgnoreCase(type)) { + manifestsCount.addAndGet(numFiles); + + } else if (MANIFEST_LIST.equalsIgnoreCase(type)) { + manifestListsCount.addAndGet(numFiles); + + } else if (STATISTICS_FILES.equalsIgnoreCase(type)) { + statisticsFilesCount.addAndGet(numFiles); + + } else if (OTHERS.equalsIgnoreCase(type)) { + otherFilesCount.addAndGet(numFiles); + + } else { + throw new ValidationException("Illegal file type: %s", type); + } + } + + public void deletedFile(String path, String type) { + if (FileContent.DATA.name().equalsIgnoreCase(type)) { + dataFilesCount.incrementAndGet(); + LOG.trace("Deleted data file: {}", path); + + } else if (FileContent.POSITION_DELETES.name().equalsIgnoreCase(type)) { + 
positionDeleteFilesCount.incrementAndGet(); + LOG.trace("Deleted positional delete file: {}", path); + + } else if (FileContent.EQUALITY_DELETES.name().equalsIgnoreCase(type)) { + equalityDeleteFilesCount.incrementAndGet(); + LOG.trace("Deleted equality delete file: {}", path); + + } else if (MANIFEST.equalsIgnoreCase(type)) { + manifestsCount.incrementAndGet(); + LOG.debug("Deleted manifest: {}", path); + + } else if (MANIFEST_LIST.equalsIgnoreCase(type)) { + manifestListsCount.incrementAndGet(); + LOG.debug("Deleted manifest list: {}", path); + + } else if (STATISTICS_FILES.equalsIgnoreCase(type)) { + statisticsFilesCount.incrementAndGet(); + LOG.debug("Deleted statistics file: {}", path); + + } else if (OTHERS.equalsIgnoreCase(type)) { + otherFilesCount.incrementAndGet(); + LOG.debug("Deleted other metadata file: {}", path); + + } else { + throw new ValidationException("Illegal file type: %s", type); + } + } + + public long dataFilesCount() { + return dataFilesCount.get(); + } + + public long positionDeleteFilesCount() { + return positionDeleteFilesCount.get(); + } + + public long equalityDeleteFilesCount() { + return equalityDeleteFilesCount.get(); + } + + public long manifestsCount() { + return manifestsCount.get(); + } + + public long manifestListsCount() { + return manifestListsCount.get(); + } + + public long statisticsFilesCount() { + return statisticsFilesCount.get(); + } + + public long otherFilesCount() { + return otherFilesCount.get(); + } + + public long totalFilesCount() { + return dataFilesCount() + + positionDeleteFilesCount() + + equalityDeleteFilesCount() + + manifestsCount() + + manifestListsCount() + + statisticsFilesCount() + + otherFilesCount(); + } + } + + private static class ReadManifest implements FlatMapFunction { + private final Broadcast
<Table> table; + + ReadManifest(Broadcast<Table>
table) { + this.table = table; + } + + @Override + public Iterator call(ManifestFileBean manifest) { + return new ClosingIterator<>(entries(manifest)); + } + + public CloseableIterator entries(ManifestFileBean manifest) { + ManifestContent content = manifest.content(); + FileIO io = table.getValue().io(); + Map specs = table.getValue().specs(); + List proj = ImmutableList.of(DataFile.FILE_PATH.name(), DataFile.CONTENT.name()); + + switch (content) { + case DATA: + return CloseableIterator.transform( + ManifestFiles.read(manifest, io, specs).select(proj).iterator(), + ReadManifest::toFileInfo); + case DELETES: + return CloseableIterator.transform( + ManifestFiles.readDeleteManifest(manifest, io, specs).select(proj).iterator(), + ReadManifest::toFileInfo); + default: + throw new IllegalArgumentException("Unsupported manifest content type:" + content); + } + } + + static FileInfo toFileInfo(ContentFile file) { + return new FileInfo(file.location(), file.content().toString()); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/BaseTableCreationSparkAction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/BaseTableCreationSparkAction.java new file mode 100644 index 000000000000..520c520484dc --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/BaseTableCreationSparkAction.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.NoSuchNamespaceException; +import org.apache.iceberg.exceptions.NoSuchTableException; +import org.apache.iceberg.mapping.MappingUtil; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.source.StagedSparkTable; +import org.apache.iceberg.util.LocationUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.catalog.CatalogTable; +import org.apache.spark.sql.catalyst.catalog.CatalogUtils; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.V1Table; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; + +abstract class BaseTableCreationSparkAction extends BaseSparkAction { + private static final Set ALLOWED_SOURCES = + ImmutableSet.of("parquet", "avro", "orc", "hive"); + protected static final String LOCATION = "location"; + protected static final String ICEBERG_METADATA_FOLDER = "metadata"; + protected static final List EXCLUDED_PROPERTIES = + ImmutableList.of("path", "transient_lastDdlTime", "serialization.format"); + + // Source Fields + private final V1Table sourceTable; + private final CatalogTable sourceCatalogTable; + private final String sourceTableLocation; + private final TableCatalog sourceCatalog; + private final Identifier sourceTableIdent; + + // Optional Parameters for destination + private final Map additionalProperties = Maps.newHashMap(); + + BaseTableCreationSparkAction( + SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) { + super(spark); + + this.sourceCatalog = checkSourceCatalog(sourceCatalog); + this.sourceTableIdent = sourceTableIdent; + + try { + this.sourceTable = (V1Table) this.sourceCatalog.loadTable(sourceTableIdent); + this.sourceCatalogTable = sourceTable.v1Table(); + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) { + throw new NoSuchTableException("Cannot not find source table '%s'", sourceTableIdent); + } catch (ClassCastException e) { + throw new IllegalArgumentException( + String.format("Cannot use non-v1 table '%s' as a source", sourceTableIdent), e); + } + validateSourceTable(); + + this.sourceTableLocation = + CatalogUtils.URIToString(sourceCatalogTable.storage().locationUri().get()); + } + + protected abstract TableCatalog checkSourceCatalog(CatalogPlugin catalog); + + protected abstract StagingTableCatalog destCatalog(); + + protected abstract Identifier destTableIdent(); + + protected abstract Map destTableProps(); + + protected String sourceTableLocation() { + return 
sourceTableLocation; + } + + protected CatalogTable v1SourceTable() { + return sourceCatalogTable; + } + + protected TableCatalog sourceCatalog() { + return sourceCatalog; + } + + protected Identifier sourceTableIdent() { + return sourceTableIdent; + } + + protected void setProperties(Map properties) { + additionalProperties.putAll(properties); + } + + protected void setProperty(String key, String value) { + additionalProperties.put(key, value); + } + + protected Map additionalProperties() { + return additionalProperties; + } + + private void validateSourceTable() { + String sourceTableProvider = sourceCatalogTable.provider().get().toLowerCase(Locale.ROOT); + Preconditions.checkArgument( + ALLOWED_SOURCES.contains(sourceTableProvider), + "Cannot create an Iceberg table from source provider: '%s'", + sourceTableProvider); + Preconditions.checkArgument( + !sourceCatalogTable.storage().locationUri().isEmpty(), + "Cannot create an Iceberg table from a source without an explicit location"); + } + + protected StagingTableCatalog checkDestinationCatalog(CatalogPlugin catalog) { + Preconditions.checkArgument( + catalog instanceof SparkSessionCatalog || catalog instanceof SparkCatalog, + "Cannot create Iceberg table in non-Iceberg Catalog. " + + "Catalog '%s' was of class '%s' but '%s' or '%s' are required", + catalog.name(), + catalog.getClass().getName(), + SparkSessionCatalog.class.getName(), + SparkCatalog.class.getName()); + + return (StagingTableCatalog) catalog; + } + + protected StagedSparkTable stageDestTable() { + try { + Map props = destTableProps(); + StructType schema = sourceTable.schema(); + Transform[] partitioning = sourceTable.partitioning(); + return (StagedSparkTable) + destCatalog().stageCreate(destTableIdent(), schema, partitioning, props); + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException( + "Cannot create table %s as the namespace does not exist", destTableIdent()); + } catch (org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException e) { + throw new AlreadyExistsException( + "Cannot create table %s as it already exists", destTableIdent()); + } + } + + protected void ensureNameMappingPresent(Table table) { + if (!table.properties().containsKey(TableProperties.DEFAULT_NAME_MAPPING)) { + NameMapping nameMapping = MappingUtil.create(table.schema()); + String nameMappingJson = NameMappingParser.toJson(nameMapping); + table.updateProperties().set(TableProperties.DEFAULT_NAME_MAPPING, nameMappingJson).commit(); + } + } + + protected String getMetadataLocation(Table table) { + String defaultValue = + LocationUtil.stripTrailingSlash(table.location()) + "/" + ICEBERG_METADATA_FOLDER; + return LocationUtil.stripTrailingSlash( + table.properties().getOrDefault(TableProperties.WRITE_METADATA_LOCATION, defaultValue)); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/ComputeTableStatsSparkAction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/ComputeTableStatsSparkAction.java new file mode 100644 index 000000000000..a508021c1040 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/ComputeTableStatsSparkAction.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.io.IOException; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; +import org.apache.iceberg.GenericBlobMetadata; +import org.apache.iceberg.GenericStatisticsFile; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.IcebergBuild; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.actions.ComputeTableStats; +import org.apache.iceberg.actions.ImmutableComputeTableStats; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.puffin.Blob; +import org.apache.iceberg.puffin.Puffin; +import org.apache.iceberg.puffin.PuffinWriter; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Computes the statistics of the given columns and stores it as Puffin files. */ +public class ComputeTableStatsSparkAction extends BaseSparkAction + implements ComputeTableStats { + + private static final Logger LOG = LoggerFactory.getLogger(ComputeTableStatsSparkAction.class); + private static final Result EMPTY_RESULT = ImmutableComputeTableStats.Result.builder().build(); + + private final Table table; + private List columns; + private Snapshot snapshot; + + ComputeTableStatsSparkAction(SparkSession spark, Table table) { + super(spark); + this.table = table; + this.snapshot = table.currentSnapshot(); + } + + @Override + protected ComputeTableStatsSparkAction self() { + return this; + } + + @Override + public ComputeTableStats columns(String... 
newColumns) { + Preconditions.checkArgument( + newColumns != null && newColumns.length > 0, "Columns cannot be null/empty"); + this.columns = ImmutableList.copyOf(ImmutableSet.copyOf(newColumns)); + return this; + } + + @Override + public ComputeTableStats snapshot(long newSnapshotId) { + Snapshot newSnapshot = table.snapshot(newSnapshotId); + Preconditions.checkArgument(newSnapshot != null, "Snapshot not found: %s", newSnapshotId); + this.snapshot = newSnapshot; + return this; + } + + @Override + public Result execute() { + if (snapshot == null) { + LOG.info("No snapshot to compute stats for table {}", table.name()); + return EMPTY_RESULT; + } + validateColumns(); + JobGroupInfo info = newJobGroupInfo("COMPUTE-TABLE-STATS", jobDesc()); + return withJobGroupInfo(info, this::doExecute); + } + + private Result doExecute() { + LOG.info( + "Computing stats for columns {} in {} (snapshot {})", + columns(), + table.name(), + snapshotId()); + List blobs = generateNDVBlobs(); + StatisticsFile statisticsFile = writeStatsFile(blobs); + table.updateStatistics().setStatistics(snapshotId(), statisticsFile).commit(); + return ImmutableComputeTableStats.Result.builder().statisticsFile(statisticsFile).build(); + } + + private StatisticsFile writeStatsFile(List blobs) { + LOG.info("Writing stats for table {} for snapshot {}", table.name(), snapshotId()); + OutputFile outputFile = table.io().newOutputFile(outputPath()); + try (PuffinWriter writer = Puffin.write(outputFile).createdBy(appIdentifier()).build()) { + blobs.forEach(writer::add); + writer.finish(); + return new GenericStatisticsFile( + snapshotId(), + outputFile.location(), + writer.fileSize(), + writer.footerSize(), + GenericBlobMetadata.from(writer.writtenBlobsMetadata())); + } catch (IOException e) { + throw new RuntimeIOException(e); + } + } + + private List generateNDVBlobs() { + return NDVSketchUtil.generateBlobs(spark(), table, snapshot, columns()); + } + + private List columns() { + if (columns == null) { + Schema schema = table.schemas().get(snapshot.schemaId()); + this.columns = + schema.columns().stream() + .filter(nestedField -> nestedField.type().isPrimitiveType()) + .map(Types.NestedField::name) + .collect(Collectors.toList()); + } + return columns; + } + + private void validateColumns() { + Schema schema = table.schemas().get(snapshot.schemaId()); + Preconditions.checkArgument(!columns().isEmpty(), "No columns found to compute stats"); + for (String columnName : columns()) { + Types.NestedField field = schema.findField(columnName); + Preconditions.checkArgument(field != null, "Can't find column %s in %s", columnName, schema); + Preconditions.checkArgument( + field.type().isPrimitiveType(), + "Can't compute stats on non-primitive type column: %s (%s)", + columnName, + field.type()); + } + } + + private String appIdentifier() { + String icebergVersion = IcebergBuild.fullVersion(); + String sparkVersion = spark().version(); + return String.format("Iceberg %s Spark %s", icebergVersion, sparkVersion); + } + + private long snapshotId() { + return snapshot.snapshotId(); + } + + private String jobDesc() { + return String.format( + "Computing table stats for %s (snapshot_id=%s, columns=%s)", + table.name(), snapshotId(), columns()); + } + + private String outputPath() { + TableOperations operations = ((HasTableOperations) table).operations(); + String fileName = String.format("%s-%s.stats", snapshotId(), UUID.randomUUID()); + return operations.metadataFileLocation(fileName); + } +} diff --git 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/DeleteOrphanFilesSparkAction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/DeleteOrphanFilesSparkAction.java new file mode 100644 index 000000000000..5fbb4117feb8 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/DeleteOrphanFilesSparkAction.java @@ -0,0 +1,676 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.GC_ENABLED_DEFAULT; + +import java.io.IOException; +import java.io.Serializable; +import java.io.UncheckedIOException; +import java.net.URI; +import java.sql.Timestamp; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.PathFilter; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.actions.ImmutableDeleteOrphanFiles; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HiddenPathFilter; +import org.apache.iceberg.io.BulkDeletionFailureException; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.base.Strings; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterators; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import 
org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.SerializableConfiguration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Tuple2; + +/** + * An action that removes orphan metadata, data and delete files by listing a given location and + * comparing the actual files in that location with content and metadata files referenced by all + * valid snapshots. The location must be accessible for listing via the Hadoop {@link FileSystem}. + * + *
<p>
By default, this action cleans up the table location returned by {@link Table#location()} and + * removes unreachable files that are older than 3 days using {@link Table#io()}. The behavior can + * be modified by passing a custom location to {@link #location} and a custom timestamp to {@link + * #olderThan(long)}. For example, someone might point this action to the data folder to clean up + * only orphan data files. + * + *
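+ * <p>A minimal usage sketch, assuming the {@code SparkActions} entry point from this package,
+ * where {@code spark} is an active session and {@code table} an Iceberg table:
+ *
+ * <pre>{@code
+ * // remove unreferenced files under the table location that are older than 10 days
+ * long olderThanMillis = System.currentTimeMillis() - TimeUnit.DAYS.toMillis(10);
+ * DeleteOrphanFiles.Result result =
+ *     SparkActions.get(spark)
+ *         .deleteOrphanFiles(table)
+ *         .olderThan(olderThanMillis)
+ *         .execute();
+ * }</pre>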
<p>
Configure an alternative delete method using {@link #deleteWith(Consumer)}. + * + *
<p>
For full control of the set of files being evaluated, use the {@link + * #compareToFileList(Dataset)} argument. This skips the directory listing - any files in the + * dataset provided which are not found in table metadata will be deleted, using the same {@link + * Table#location()} and {@link #olderThan(long)} filtering as above. + * + *
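+ * <p>The dataset passed to {@link #compareToFileList(Dataset)} must expose a string path column
+ * and a timestamp last-modified column; both are validated before the comparison runs.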
<p>
Note: It is dangerous to call this action with a short retention interval as it might + * corrupt the state of the table if another operation is writing at the same time. + */ +public class DeleteOrphanFilesSparkAction extends BaseSparkAction + implements DeleteOrphanFiles { + + private static final Logger LOG = LoggerFactory.getLogger(DeleteOrphanFilesSparkAction.class); + private static final Map EQUAL_SCHEMES_DEFAULT = ImmutableMap.of("s3n,s3a", "s3"); + private static final int MAX_DRIVER_LISTING_DEPTH = 3; + private static final int MAX_DRIVER_LISTING_DIRECT_SUB_DIRS = 10; + private static final int MAX_EXECUTOR_LISTING_DEPTH = 2000; + private static final int MAX_EXECUTOR_LISTING_DIRECT_SUB_DIRS = Integer.MAX_VALUE; + + private final SerializableConfiguration hadoopConf; + private final int listingParallelism; + private final Table table; + private Map equalSchemes = flattenMap(EQUAL_SCHEMES_DEFAULT); + private Map equalAuthorities = Collections.emptyMap(); + private PrefixMismatchMode prefixMismatchMode = PrefixMismatchMode.ERROR; + private String location; + private long olderThanTimestamp = System.currentTimeMillis() - TimeUnit.DAYS.toMillis(3); + private Dataset compareToFileList; + private Consumer deleteFunc = null; + private ExecutorService deleteExecutorService = null; + + DeleteOrphanFilesSparkAction(SparkSession spark, Table table) { + super(spark); + + this.hadoopConf = new SerializableConfiguration(spark.sessionState().newHadoopConf()); + this.listingParallelism = spark.sessionState().conf().parallelPartitionDiscoveryParallelism(); + this.table = table; + this.location = table.location(); + + ValidationException.check( + PropertyUtil.propertyAsBoolean(table.properties(), GC_ENABLED, GC_ENABLED_DEFAULT), + "Cannot delete orphan files: GC is disabled (deleting files may corrupt other tables)"); + } + + @Override + protected DeleteOrphanFilesSparkAction self() { + return this; + } + + @Override + public DeleteOrphanFilesSparkAction executeDeleteWith(ExecutorService executorService) { + this.deleteExecutorService = executorService; + return this; + } + + @Override + public DeleteOrphanFilesSparkAction prefixMismatchMode(PrefixMismatchMode newPrefixMismatchMode) { + this.prefixMismatchMode = newPrefixMismatchMode; + return this; + } + + @Override + public DeleteOrphanFilesSparkAction equalSchemes(Map newEqualSchemes) { + this.equalSchemes = Maps.newHashMap(); + equalSchemes.putAll(flattenMap(EQUAL_SCHEMES_DEFAULT)); + equalSchemes.putAll(flattenMap(newEqualSchemes)); + return this; + } + + @Override + public DeleteOrphanFilesSparkAction equalAuthorities(Map newEqualAuthorities) { + this.equalAuthorities = Maps.newHashMap(); + equalAuthorities.putAll(flattenMap(newEqualAuthorities)); + return this; + } + + @Override + public DeleteOrphanFilesSparkAction location(String newLocation) { + this.location = newLocation; + return this; + } + + @Override + public DeleteOrphanFilesSparkAction olderThan(long newOlderThanTimestamp) { + this.olderThanTimestamp = newOlderThanTimestamp; + return this; + } + + @Override + public DeleteOrphanFilesSparkAction deleteWith(Consumer newDeleteFunc) { + this.deleteFunc = newDeleteFunc; + return this; + } + + public DeleteOrphanFilesSparkAction compareToFileList(Dataset files) { + StructType schema = files.schema(); + + StructField filePathField = schema.apply(FILE_PATH); + Preconditions.checkArgument( + filePathField.dataType() == DataTypes.StringType, + "Invalid %s column: %s is not a string", + FILE_PATH, + filePathField.dataType()); + + 
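+ // the last-modified column must be a Spark timestamp so it can be compared against olderThanTimestamp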
StructField lastModifiedField = schema.apply(LAST_MODIFIED); + Preconditions.checkArgument( + lastModifiedField.dataType() == DataTypes.TimestampType, + "Invalid %s column: %s is not a timestamp", + LAST_MODIFIED, + lastModifiedField.dataType()); + + this.compareToFileList = files; + return this; + } + + private Dataset filteredCompareToFileList() { + Dataset files = compareToFileList; + if (location != null) { + files = files.filter(files.col(FILE_PATH).startsWith(location)); + } + return files + .filter(files.col(LAST_MODIFIED).lt(new Timestamp(olderThanTimestamp))) + .select(files.col(FILE_PATH)) + .as(Encoders.STRING()); + } + + @Override + public DeleteOrphanFiles.Result execute() { + JobGroupInfo info = newJobGroupInfo("DELETE-ORPHAN-FILES", jobDesc()); + return withJobGroupInfo(info, this::doExecute); + } + + private String jobDesc() { + List options = Lists.newArrayList(); + options.add("older_than=" + olderThanTimestamp); + if (location != null) { + options.add("location=" + location); + } + String optionsAsString = COMMA_JOINER.join(options); + return String.format("Deleting orphan files (%s) from %s", optionsAsString, table.name()); + } + + private void deleteFiles(SupportsBulkOperations io, List paths) { + try { + io.deleteFiles(paths); + LOG.info("Deleted {} files using bulk deletes", paths.size()); + } catch (BulkDeletionFailureException e) { + int deletedFilesCount = paths.size() - e.numberFailedObjects(); + LOG.warn("Deleted only {} of {} files using bulk deletes", deletedFilesCount, paths.size()); + } + } + + private DeleteOrphanFiles.Result doExecute() { + Dataset actualFileIdentDS = actualFileIdentDS(); + Dataset validFileIdentDS = validFileIdentDS(); + + List orphanFiles = + findOrphanFiles(spark(), actualFileIdentDS, validFileIdentDS, prefixMismatchMode); + + if (deleteFunc == null && table.io() instanceof SupportsBulkOperations) { + deleteFiles((SupportsBulkOperations) table.io(), orphanFiles); + } else { + + Tasks.Builder deleteTasks = + Tasks.foreach(orphanFiles) + .noRetry() + .executeWith(deleteExecutorService) + .suppressFailureWhenFinished() + .onFailure((file, exc) -> LOG.warn("Failed to delete file: {}", file, exc)); + + if (deleteFunc == null) { + LOG.info( + "Table IO {} does not support bulk operations. Using non-bulk deletes.", + table.io().getClass().getName()); + deleteTasks.run(table.io()::deleteFile); + } else { + LOG.info("Custom delete function provided. 
Using non-bulk deletes"); + deleteTasks.run(deleteFunc::accept); + } + } + + return ImmutableDeleteOrphanFiles.Result.builder().orphanFileLocations(orphanFiles).build(); + } + + private Dataset validFileIdentDS() { + // transform before union to avoid extra serialization/deserialization + FileInfoToFileURI toFileURI = new FileInfoToFileURI(equalSchemes, equalAuthorities); + + Dataset contentFileIdentDS = toFileURI.apply(contentFileDS(table)); + Dataset manifestFileIdentDS = toFileURI.apply(manifestDS(table)); + Dataset manifestListIdentDS = toFileURI.apply(manifestListDS(table)); + Dataset otherMetadataFileIdentDS = toFileURI.apply(otherMetadataFileDS(table)); + + return contentFileIdentDS + .union(manifestFileIdentDS) + .union(manifestListIdentDS) + .union(otherMetadataFileIdentDS); + } + + private Dataset actualFileIdentDS() { + StringToFileURI toFileURI = new StringToFileURI(equalSchemes, equalAuthorities); + if (compareToFileList == null) { + return toFileURI.apply(listedFileDS()); + } else { + return toFileURI.apply(filteredCompareToFileList()); + } + } + + private Dataset listedFileDS() { + List subDirs = Lists.newArrayList(); + List matchingFiles = Lists.newArrayList(); + + Predicate predicate = file -> file.getModificationTime() < olderThanTimestamp; + PathFilter pathFilter = PartitionAwareHiddenPathFilter.forSpecs(table.specs()); + + // list at most MAX_DRIVER_LISTING_DEPTH levels and only dirs that have + // less than MAX_DRIVER_LISTING_DIRECT_SUB_DIRS direct sub dirs on the driver + listDirRecursively( + location, + predicate, + hadoopConf.value(), + MAX_DRIVER_LISTING_DEPTH, + MAX_DRIVER_LISTING_DIRECT_SUB_DIRS, + subDirs, + pathFilter, + matchingFiles); + + JavaRDD matchingFileRDD = sparkContext().parallelize(matchingFiles, 1); + + if (subDirs.isEmpty()) { + return spark().createDataset(matchingFileRDD.rdd(), Encoders.STRING()); + } + + int parallelism = Math.min(subDirs.size(), listingParallelism); + JavaRDD subDirRDD = sparkContext().parallelize(subDirs, parallelism); + + Broadcast conf = sparkContext().broadcast(hadoopConf); + ListDirsRecursively listDirs = new ListDirsRecursively(conf, olderThanTimestamp, pathFilter); + JavaRDD matchingLeafFileRDD = subDirRDD.mapPartitions(listDirs); + + JavaRDD completeMatchingFileRDD = matchingFileRDD.union(matchingLeafFileRDD); + return spark().createDataset(completeMatchingFileRDD.rdd(), Encoders.STRING()); + } + + private static void listDirRecursively( + String dir, + Predicate predicate, + Configuration conf, + int maxDepth, + int maxDirectSubDirs, + List remainingSubDirs, + PathFilter pathFilter, + List matchingFiles) { + + // stop listing whenever we reach the max depth + if (maxDepth <= 0) { + remainingSubDirs.add(dir); + return; + } + + try { + Path path = new Path(dir); + FileSystem fs = path.getFileSystem(conf); + + List subDirs = Lists.newArrayList(); + + for (FileStatus file : fs.listStatus(path, pathFilter)) { + if (file.isDirectory()) { + subDirs.add(file.getPath().toString()); + } else if (file.isFile() && predicate.test(file)) { + matchingFiles.add(file.getPath().toString()); + } + } + + // stop listing if the number of direct sub dirs is bigger than maxDirectSubDirs + if (subDirs.size() > maxDirectSubDirs) { + remainingSubDirs.addAll(subDirs); + return; + } + + for (String subDir : subDirs) { + listDirRecursively( + subDir, + predicate, + conf, + maxDepth - 1, + maxDirectSubDirs, + remainingSubDirs, + pathFilter, + matchingFiles); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + 
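+ // Joins the listed/provided file URIs against the file URIs referenced by table metadata using a
+ // left-outer join on the URI path; rows without a metadata match are orphan candidates, and
+ // scheme/authority differences on matched paths are resolved per the configured PrefixMismatchMode.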
@VisibleForTesting + static List findOrphanFiles( + SparkSession spark, + Dataset actualFileIdentDS, + Dataset validFileIdentDS, + PrefixMismatchMode prefixMismatchMode) { + + SetAccumulator> conflicts = new SetAccumulator<>(); + spark.sparkContext().register(conflicts); + + Column joinCond = actualFileIdentDS.col("path").equalTo(validFileIdentDS.col("path")); + + List orphanFiles = + actualFileIdentDS + .joinWith(validFileIdentDS, joinCond, "leftouter") + .mapPartitions(new FindOrphanFiles(prefixMismatchMode, conflicts), Encoders.STRING()) + .collectAsList(); + + if (prefixMismatchMode == PrefixMismatchMode.ERROR && !conflicts.value().isEmpty()) { + throw new ValidationException( + "Unable to determine whether certain files are orphan. " + + "Metadata references files that match listed/provided files except for authority/scheme. " + + "Please, inspect the conflicting authorities/schemes and provide which of them are equal " + + "by further configuring the action via equalSchemes() and equalAuthorities() methods. " + + "Set the prefix mismatch mode to 'NONE' to ignore remaining locations with conflicting " + + "authorities/schemes or to 'DELETE' iff you are ABSOLUTELY confident that remaining conflicting " + + "authorities/schemes are different. It will be impossible to recover deleted files. " + + "Conflicting authorities/schemes: %s.", + conflicts.value()); + } + + return orphanFiles; + } + + private static Map flattenMap(Map map) { + Map flattenedMap = Maps.newHashMap(); + if (map != null) { + for (String key : map.keySet()) { + String value = map.get(key); + for (String splitKey : COMMA_SPLITTER.split(key)) { + flattenedMap.put(splitKey.trim(), value.trim()); + } + } + } + return flattenedMap; + } + + private static class ListDirsRecursively implements FlatMapFunction, String> { + + private final Broadcast hadoopConf; + private final long olderThanTimestamp; + private final PathFilter pathFilter; + + ListDirsRecursively( + Broadcast hadoopConf, + long olderThanTimestamp, + PathFilter pathFilter) { + + this.hadoopConf = hadoopConf; + this.olderThanTimestamp = olderThanTimestamp; + this.pathFilter = pathFilter; + } + + @Override + public Iterator call(Iterator dirs) throws Exception { + List subDirs = Lists.newArrayList(); + List files = Lists.newArrayList(); + + Predicate predicate = file -> file.getModificationTime() < olderThanTimestamp; + + while (dirs.hasNext()) { + listDirRecursively( + dirs.next(), + predicate, + hadoopConf.value().value(), + MAX_EXECUTOR_LISTING_DEPTH, + MAX_EXECUTOR_LISTING_DIRECT_SUB_DIRS, + subDirs, + pathFilter, + files); + } + + if (!subDirs.isEmpty()) { + throw new RuntimeException( + "Could not list sub directories, reached maximum depth: " + MAX_EXECUTOR_LISTING_DEPTH); + } + + return files.iterator(); + } + } + + private static class FindOrphanFiles + implements MapPartitionsFunction, String> { + + private final PrefixMismatchMode mode; + private final SetAccumulator> conflicts; + + FindOrphanFiles(PrefixMismatchMode mode, SetAccumulator> conflicts) { + this.mode = mode; + this.conflicts = conflicts; + } + + @Override + public Iterator call(Iterator> rows) throws Exception { + Iterator orphanFiles = Iterators.transform(rows, this::toOrphanFile); + return Iterators.filter(orphanFiles, Objects::nonNull); + } + + private String toOrphanFile(Tuple2 row) { + FileURI actual = row._1; + FileURI valid = row._2; + + if (valid == null) { + return actual.uriAsString; + } + + boolean schemeMatch = uriComponentMatch(valid.scheme, actual.scheme); + boolean 
authorityMatch = uriComponentMatch(valid.authority, actual.authority); + + if ((!schemeMatch || !authorityMatch) && mode == PrefixMismatchMode.DELETE) { + return actual.uriAsString; + } else { + if (!schemeMatch) { + conflicts.add(Pair.of(valid.scheme, actual.scheme)); + } + + if (!authorityMatch) { + conflicts.add(Pair.of(valid.authority, actual.authority)); + } + + return null; + } + } + + private boolean uriComponentMatch(String valid, String actual) { + return Strings.isNullOrEmpty(valid) || valid.equalsIgnoreCase(actual); + } + } + + @VisibleForTesting + static class StringToFileURI extends ToFileURI { + StringToFileURI(Map equalSchemes, Map equalAuthorities) { + super(equalSchemes, equalAuthorities); + } + + @Override + protected String uriAsString(String input) { + return input; + } + } + + @VisibleForTesting + static class FileInfoToFileURI extends ToFileURI { + FileInfoToFileURI(Map equalSchemes, Map equalAuthorities) { + super(equalSchemes, equalAuthorities); + } + + @Override + protected String uriAsString(FileInfo fileInfo) { + return fileInfo.getPath(); + } + } + + private abstract static class ToFileURI implements MapPartitionsFunction { + + private final Map equalSchemes; + private final Map equalAuthorities; + + ToFileURI(Map equalSchemes, Map equalAuthorities) { + this.equalSchemes = equalSchemes; + this.equalAuthorities = equalAuthorities; + } + + protected abstract String uriAsString(I input); + + Dataset apply(Dataset ds) { + return ds.mapPartitions(this, FileURI.ENCODER); + } + + @Override + public Iterator call(Iterator rows) throws Exception { + return Iterators.transform(rows, this::toFileURI); + } + + private FileURI toFileURI(I input) { + String uriAsString = uriAsString(input); + URI uri = new Path(uriAsString).toUri(); + String scheme = equalSchemes.getOrDefault(uri.getScheme(), uri.getScheme()); + String authority = equalAuthorities.getOrDefault(uri.getAuthority(), uri.getAuthority()); + return new FileURI(scheme, authority, uri.getPath(), uriAsString); + } + } + + /** + * A {@link PathFilter} that filters out hidden path, but does not filter out paths that would be + * marked as hidden by {@link HiddenPathFilter} due to a partition field that starts with one of + * the characters that indicate a hidden path. 
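+ * For example, a partition field named {@code _processed_date} produces directories such as
+ * {@code _processed_date=2024-01-01/} that begin with an underscore; this filter keeps those
+ * paths while other hidden paths remain filtered out.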
+ */ + @VisibleForTesting + static class PartitionAwareHiddenPathFilter implements PathFilter, Serializable { + + private final Set hiddenPathPartitionNames; + + PartitionAwareHiddenPathFilter(Set hiddenPathPartitionNames) { + this.hiddenPathPartitionNames = hiddenPathPartitionNames; + } + + @Override + public boolean accept(Path path) { + return isHiddenPartitionPath(path) || HiddenPathFilter.get().accept(path); + } + + private boolean isHiddenPartitionPath(Path path) { + return hiddenPathPartitionNames.stream().anyMatch(path.getName()::startsWith); + } + + static PathFilter forSpecs(Map specs) { + if (specs == null) { + return HiddenPathFilter.get(); + } + + Set partitionNames = + specs.values().stream() + .map(PartitionSpec::fields) + .flatMap(List::stream) + .filter(field -> field.name().startsWith("_") || field.name().startsWith(".")) + .map(field -> field.name() + "=") + .collect(Collectors.toSet()); + + if (partitionNames.isEmpty()) { + return HiddenPathFilter.get(); + } else { + return new PartitionAwareHiddenPathFilter(partitionNames); + } + } + } + + public static class FileURI { + public static final Encoder ENCODER = Encoders.bean(FileURI.class); + + private String scheme; + private String authority; + private String path; + private String uriAsString; + + public FileURI(String scheme, String authority, String path, String uriAsString) { + this.scheme = scheme; + this.authority = authority; + this.path = path; + this.uriAsString = uriAsString; + } + + public FileURI() {} + + public void setScheme(String scheme) { + this.scheme = scheme; + } + + public void setAuthority(String authority) { + this.authority = authority; + } + + public void setPath(String path) { + this.path = path; + } + + public void setUriAsString(String uriAsString) { + this.uriAsString = uriAsString; + } + + public String getScheme() { + return scheme; + } + + public String getAuthority() { + return authority; + } + + public String getPath() { + return path; + } + + public String getUriAsString() { + return uriAsString; + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/DeleteReachableFilesSparkAction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/DeleteReachableFilesSparkAction.java new file mode 100644 index 000000000000..ea6ac9f3dbf5 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/DeleteReachableFilesSparkAction.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.GC_ENABLED_DEFAULT; + +import java.util.Iterator; +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableMetadataParser; +import org.apache.iceberg.actions.DeleteReachableFiles; +import org.apache.iceberg.actions.ImmutableDeleteReachableFiles; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopFileIO; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An implementation of {@link DeleteReachableFiles} that uses metadata tables in Spark to determine + * which files should be deleted. + */ +@SuppressWarnings("UnnecessaryAnonymousClass") +public class DeleteReachableFilesSparkAction + extends BaseSparkAction implements DeleteReachableFiles { + + public static final String STREAM_RESULTS = "stream-results"; + public static final boolean STREAM_RESULTS_DEFAULT = false; + + private static final Logger LOG = LoggerFactory.getLogger(DeleteReachableFilesSparkAction.class); + + private final String metadataFileLocation; + + private Consumer deleteFunc = null; + private ExecutorService deleteExecutorService = null; + private FileIO io = new HadoopFileIO(spark().sessionState().newHadoopConf()); + + DeleteReachableFilesSparkAction(SparkSession spark, String metadataFileLocation) { + super(spark); + this.metadataFileLocation = metadataFileLocation; + } + + @Override + protected DeleteReachableFilesSparkAction self() { + return this; + } + + @Override + public DeleteReachableFilesSparkAction io(FileIO fileIO) { + this.io = fileIO; + return this; + } + + @Override + public DeleteReachableFilesSparkAction deleteWith(Consumer newDeleteFunc) { + this.deleteFunc = newDeleteFunc; + return this; + } + + @Override + public DeleteReachableFilesSparkAction executeDeleteWith(ExecutorService executorService) { + this.deleteExecutorService = executorService; + return this; + } + + @Override + public Result execute() { + Preconditions.checkArgument(io != null, "File IO cannot be null"); + String jobDesc = String.format("Deleting files reachable from %s", metadataFileLocation); + JobGroupInfo info = newJobGroupInfo("DELETE-REACHABLE-FILES", jobDesc); + return withJobGroupInfo(info, this::doExecute); + } + + private Result doExecute() { + TableMetadata metadata = TableMetadataParser.read(io, metadataFileLocation); + + ValidationException.check( + PropertyUtil.propertyAsBoolean(metadata.properties(), GC_ENABLED, GC_ENABLED_DEFAULT), + "Cannot delete files: GC is disabled (deleting files may corrupt other tables)"); + + Dataset reachableFileDS = reachableFileDS(metadata); + + if (streamResults()) { + return deleteFiles(reachableFileDS.toLocalIterator()); + } else { + return deleteFiles(reachableFileDS.collectAsList().iterator()); + } + } + + private boolean streamResults() { + return PropertyUtil.propertyAsBoolean(options(), STREAM_RESULTS, STREAM_RESULTS_DEFAULT); + } + + private Dataset reachableFileDS(TableMetadata metadata) { 
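+ // build a static table over the given metadata and union every reachable file type:
+ // content files, manifests, manifest lists and all other reachable metadata files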
+ Table staticTable = newStaticTable(metadata, io); + return contentFileDS(staticTable) + .union(manifestDS(staticTable)) + .union(manifestListDS(staticTable)) + .union(allReachableOtherMetadataFileDS(staticTable)) + .distinct(); + } + + private DeleteReachableFiles.Result deleteFiles(Iterator files) { + DeleteSummary summary; + if (deleteFunc == null && io instanceof SupportsBulkOperations) { + summary = deleteFiles((SupportsBulkOperations) io, files); + } else { + + if (deleteFunc == null) { + LOG.info( + "Table IO {} does not support bulk operations. Using non-bulk deletes.", + io.getClass().getName()); + summary = deleteFiles(deleteExecutorService, io::deleteFile, files); + } else { + LOG.info("Custom delete function provided. Using non-bulk deletes"); + summary = deleteFiles(deleteExecutorService, deleteFunc, files); + } + } + + LOG.info("Deleted {} total files", summary.totalFilesCount()); + + return ImmutableDeleteReachableFiles.Result.builder() + .deletedDataFilesCount(summary.dataFilesCount()) + .deletedPositionDeleteFilesCount(summary.positionDeleteFilesCount()) + .deletedEqualityDeleteFilesCount(summary.equalityDeleteFilesCount()) + .deletedManifestsCount(summary.manifestsCount()) + .deletedManifestListsCount(summary.manifestListsCount()) + .deletedOtherFilesCount(summary.otherFilesCount()) + .build(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/ExpireSnapshotsSparkAction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/ExpireSnapshotsSparkAction.java new file mode 100644 index 000000000000..2468497e42d0 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/ExpireSnapshotsSparkAction.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.GC_ENABLED_DEFAULT; + +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.actions.ExpireSnapshots; +import org.apache.iceberg.actions.ImmutableExpireSnapshots; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An action that performs the same operation as {@link org.apache.iceberg.ExpireSnapshots} but uses + * Spark to determine the delta in files between the pre and post-expiration table metadata. All of + * the same restrictions of {@link org.apache.iceberg.ExpireSnapshots} also apply to this action. + * + *
<p>
This action first leverages {@link org.apache.iceberg.ExpireSnapshots} to expire snapshots and + * then uses metadata tables to find files that can be safely deleted. This is done by anti-joining + * two Datasets that contain all manifest and content files before and after the expiration. The + * snapshot expiration will be fully committed before any deletes are issued. + * + *
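+ * <p>A minimal usage sketch, assuming the {@code SparkActions} entry point from this package,
+ * where {@code spark} is an active session and {@code table} an Iceberg table:
+ *
+ * <pre>{@code
+ * // expire snapshots older than 7 days while retaining at least the 10 most recent ones
+ * long olderThanMillis = System.currentTimeMillis() - TimeUnit.DAYS.toMillis(7);
+ * ExpireSnapshots.Result result =
+ *     SparkActions.get(spark)
+ *         .expireSnapshots(table)
+ *         .expireOlderThan(olderThanMillis)
+ *         .retainLast(10)
+ *         .execute();
+ * }</pre>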
<p>
This operation performs a shuffle so the parallelism can be controlled through + * 'spark.sql.shuffle.partitions'. + * + *
<p>
Deletes are still performed locally after retrieving the results from the Spark executors. + */ +@SuppressWarnings("UnnecessaryAnonymousClass") +public class ExpireSnapshotsSparkAction extends BaseSparkAction + implements ExpireSnapshots { + + public static final String STREAM_RESULTS = "stream-results"; + public static final boolean STREAM_RESULTS_DEFAULT = false; + + private static final Logger LOG = LoggerFactory.getLogger(ExpireSnapshotsSparkAction.class); + + private final Table table; + private final TableOperations ops; + + private final Set expiredSnapshotIds = Sets.newHashSet(); + private Long expireOlderThanValue = null; + private Integer retainLastValue = null; + private Consumer deleteFunc = null; + private ExecutorService deleteExecutorService = null; + private Dataset expiredFileDS = null; + + ExpireSnapshotsSparkAction(SparkSession spark, Table table) { + super(spark); + this.table = table; + this.ops = ((HasTableOperations) table).operations(); + + ValidationException.check( + PropertyUtil.propertyAsBoolean(table.properties(), GC_ENABLED, GC_ENABLED_DEFAULT), + "Cannot expire snapshots: GC is disabled (deleting files may corrupt other tables)"); + } + + @Override + protected ExpireSnapshotsSparkAction self() { + return this; + } + + @Override + public ExpireSnapshotsSparkAction executeDeleteWith(ExecutorService executorService) { + this.deleteExecutorService = executorService; + return this; + } + + @Override + public ExpireSnapshotsSparkAction expireSnapshotId(long snapshotId) { + expiredSnapshotIds.add(snapshotId); + return this; + } + + @Override + public ExpireSnapshotsSparkAction expireOlderThan(long timestampMillis) { + this.expireOlderThanValue = timestampMillis; + return this; + } + + @Override + public ExpireSnapshotsSparkAction retainLast(int numSnapshots) { + Preconditions.checkArgument( + 1 <= numSnapshots, + "Number of snapshots to retain must be at least 1, cannot be: %s", + numSnapshots); + this.retainLastValue = numSnapshots; + return this; + } + + @Override + public ExpireSnapshotsSparkAction deleteWith(Consumer newDeleteFunc) { + this.deleteFunc = newDeleteFunc; + return this; + } + + /** + * Expires snapshots and commits the changes to the table, returning a Dataset of files to delete. + * + *
<p>
This does not delete data files. To delete data files, run {@link #execute()}. + * + *
<p>
This may be called before or after {@link #execute()} to return the expired files. + * + * @return a Dataset of files that are no longer referenced by the table + */ + public Dataset expireFiles() { + if (expiredFileDS == null) { + // fetch metadata before expiration + TableMetadata originalMetadata = ops.current(); + + // perform expiration + org.apache.iceberg.ExpireSnapshots expireSnapshots = table.expireSnapshots(); + + for (long id : expiredSnapshotIds) { + expireSnapshots = expireSnapshots.expireSnapshotId(id); + } + + if (expireOlderThanValue != null) { + expireSnapshots = expireSnapshots.expireOlderThan(expireOlderThanValue); + } + + if (retainLastValue != null) { + expireSnapshots = expireSnapshots.retainLast(retainLastValue); + } + + expireSnapshots.cleanExpiredFiles(false).commit(); + + // fetch valid files after expiration + TableMetadata updatedMetadata = ops.refresh(); + Dataset validFileDS = fileDS(updatedMetadata); + + // fetch files referenced by expired snapshots + Set deletedSnapshotIds = findExpiredSnapshotIds(originalMetadata, updatedMetadata); + Dataset deleteCandidateFileDS = fileDS(originalMetadata, deletedSnapshotIds); + + // determine expired files + this.expiredFileDS = deleteCandidateFileDS.except(validFileDS); + } + + return expiredFileDS; + } + + @Override + public ExpireSnapshots.Result execute() { + JobGroupInfo info = newJobGroupInfo("EXPIRE-SNAPSHOTS", jobDesc()); + return withJobGroupInfo(info, this::doExecute); + } + + private String jobDesc() { + List options = Lists.newArrayList(); + + if (expireOlderThanValue != null) { + options.add("older_than=" + expireOlderThanValue); + } + + if (retainLastValue != null) { + options.add("retain_last=" + retainLastValue); + } + + if (!expiredSnapshotIds.isEmpty()) { + Long first = expiredSnapshotIds.stream().findFirst().get(); + if (expiredSnapshotIds.size() > 1) { + options.add( + String.format("snapshot_ids: %s (%s more...)", first, expiredSnapshotIds.size() - 1)); + } else { + options.add(String.format("snapshot_id: %s", first)); + } + } + + return String.format("Expiring snapshots (%s) in %s", COMMA_JOINER.join(options), table.name()); + } + + private ExpireSnapshots.Result doExecute() { + if (streamResults()) { + return deleteFiles(expireFiles().toLocalIterator()); + } else { + return deleteFiles(expireFiles().collectAsList().iterator()); + } + } + + private boolean streamResults() { + return PropertyUtil.propertyAsBoolean(options(), STREAM_RESULTS, STREAM_RESULTS_DEFAULT); + } + + private Dataset fileDS(TableMetadata metadata) { + return fileDS(metadata, null); + } + + private Dataset fileDS(TableMetadata metadata, Set snapshotIds) { + Table staticTable = newStaticTable(metadata, table.io()); + return contentFileDS(staticTable, snapshotIds) + .union(manifestDS(staticTable, snapshotIds)) + .union(manifestListDS(staticTable, snapshotIds)) + .union(statisticsFileDS(staticTable, snapshotIds)); + } + + private Set findExpiredSnapshotIds( + TableMetadata originalMetadata, TableMetadata updatedMetadata) { + Set retainedSnapshots = + updatedMetadata.snapshots().stream().map(Snapshot::snapshotId).collect(Collectors.toSet()); + return originalMetadata.snapshots().stream() + .map(Snapshot::snapshotId) + .filter(id -> !retainedSnapshots.contains(id)) + .collect(Collectors.toSet()); + } + + private ExpireSnapshots.Result deleteFiles(Iterator files) { + DeleteSummary summary; + if (deleteFunc == null && table.io() instanceof SupportsBulkOperations) { + summary = deleteFiles((SupportsBulkOperations) table.io(), files); + } 
else { + + if (deleteFunc == null) { + LOG.info( + "Table IO {} does not support bulk operations. Using non-bulk deletes.", + table.io().getClass().getName()); + summary = deleteFiles(deleteExecutorService, table.io()::deleteFile, files); + } else { + LOG.info("Custom delete function provided. Using non-bulk deletes"); + summary = deleteFiles(deleteExecutorService, deleteFunc, files); + } + } + + LOG.info("Deleted {} total files", summary.totalFilesCount()); + + return ImmutableExpireSnapshots.Result.builder() + .deletedDataFilesCount(summary.dataFilesCount()) + .deletedPositionDeleteFilesCount(summary.positionDeleteFilesCount()) + .deletedEqualityDeleteFilesCount(summary.equalityDeleteFilesCount()) + .deletedManifestsCount(summary.manifestsCount()) + .deletedManifestListsCount(summary.manifestListsCount()) + .deletedStatisticsFilesCount(summary.statisticsFilesCount()) + .build(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/FileInfo.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/FileInfo.java new file mode 100644 index 000000000000..51ff7c80fd18 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/FileInfo.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; + +public class FileInfo { + public static final Encoder ENCODER = Encoders.bean(FileInfo.class); + + private String path; + private String type; + + public FileInfo(String path, String type) { + this.path = path; + this.type = type; + } + + public FileInfo() {} + + public String getPath() { + return path; + } + + public void setPath(String path) { + this.path = path; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/ManifestFileBean.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/ManifestFileBean.java new file mode 100644 index 000000000000..11ad834244ed --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/ManifestFileBean.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.util.List; +import org.apache.iceberg.ManifestContent; +import org.apache.iceberg.ManifestFile; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; + +/** A serializable bean that contains a bare minimum to read a manifest. */ +public class ManifestFileBean implements ManifestFile, Serializable { + public static final Encoder ENCODER = Encoders.bean(ManifestFileBean.class); + + private String path = null; + private Long length = null; + private Integer partitionSpecId = null; + private Long addedSnapshotId = null; + private Integer content = null; + private Long sequenceNumber = null; + + public static ManifestFileBean fromManifest(ManifestFile manifest) { + ManifestFileBean bean = new ManifestFileBean(); + + bean.setPath(manifest.path()); + bean.setLength(manifest.length()); + bean.setPartitionSpecId(manifest.partitionSpecId()); + bean.setAddedSnapshotId(manifest.snapshotId()); + bean.setContent(manifest.content().id()); + bean.setSequenceNumber(manifest.sequenceNumber()); + + return bean; + } + + public String getPath() { + return path; + } + + public void setPath(String path) { + this.path = path; + } + + public Long getLength() { + return length; + } + + public void setLength(Long length) { + this.length = length; + } + + public Integer getPartitionSpecId() { + return partitionSpecId; + } + + public void setPartitionSpecId(Integer partitionSpecId) { + this.partitionSpecId = partitionSpecId; + } + + public Long getAddedSnapshotId() { + return addedSnapshotId; + } + + public void setAddedSnapshotId(Long addedSnapshotId) { + this.addedSnapshotId = addedSnapshotId; + } + + public Integer getContent() { + return content; + } + + public void setContent(Integer content) { + this.content = content; + } + + public Long getSequenceNumber() { + return sequenceNumber; + } + + public void setSequenceNumber(Long sequenceNumber) { + this.sequenceNumber = sequenceNumber; + } + + @Override + public String path() { + return path; + } + + @Override + public long length() { + return length; + } + + @Override + public int partitionSpecId() { + return partitionSpecId; + } + + @Override + public ManifestContent content() { + return ManifestContent.fromId(content); + } + + @Override + public long sequenceNumber() { + return sequenceNumber; + } + + @Override + public long minSequenceNumber() { + return 0; + } + + @Override + public Long snapshotId() { + return addedSnapshotId; + } + + @Override + public Integer addedFilesCount() { + return null; + } + + @Override + public Long addedRowsCount() { + return null; + } + + @Override + public Integer existingFilesCount() { + return null; + } + + @Override + public Long existingRowsCount() { + return null; + } + + @Override + public Integer deletedFilesCount() { + return null; + } + + @Override + public Long deletedRowsCount() { + return null; + } + + @Override + public List partitions() { + return null; + } + + @Override + public ByteBuffer keyMetadata() { + return null; + } + + @Override + public 
ManifestFile copy() { + throw new UnsupportedOperationException("Cannot copy"); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java new file mode 100644 index 000000000000..bdffeb465405 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.util.Map; +import java.util.concurrent.ExecutorService; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.ImmutableMigrateTable; +import org.apache.iceberg.actions.MigrateTable; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.NoSuchTableException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.source.StagedSparkTable; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Some; +import scala.collection.JavaConverters; + +/** + * Takes a Spark table in the source catalog and attempts to transform it into an Iceberg table in + * the same location with the same identifier. Once complete the identifier which previously + * referred to a non-Iceberg table will refer to the newly migrated Iceberg table. 
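+ *
+ * <p>A minimal usage sketch (illustrative only; assumes a {@code SparkSession} named {@code spark}
+ * and a hypothetical source table identifier, both to be adapted to your environment):
+ *
+ * <pre>{@code
+ * MigrateTable.Result result =
+ *     SparkActions.get(spark)
+ *         .migrateTable("spark_catalog.db.sample")   // hypothetical identifier
+ *         .tableProperty("migrated-by", "ci")        // optional extra table property
+ *         .dropBackup()                              // drop the renamed backup once migration succeeds
+ *         .execute();
+ * long migratedFiles = result.migratedDataFilesCount();
+ * }</pre>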
+ */ +public class MigrateTableSparkAction extends BaseTableCreationSparkAction + implements MigrateTable { + + private static final Logger LOG = LoggerFactory.getLogger(MigrateTableSparkAction.class); + private static final String BACKUP_SUFFIX = "_BACKUP_"; + + private final StagingTableCatalog destCatalog; + private final Identifier destTableIdent; + + private Identifier backupIdent; + private boolean dropBackup = false; + private ExecutorService executorService; + + MigrateTableSparkAction( + SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) { + super(spark, sourceCatalog, sourceTableIdent); + this.destCatalog = checkDestinationCatalog(sourceCatalog); + this.destTableIdent = sourceTableIdent; + String backupName = sourceTableIdent.name() + BACKUP_SUFFIX; + this.backupIdent = Identifier.of(sourceTableIdent.namespace(), backupName); + } + + @Override + protected MigrateTableSparkAction self() { + return this; + } + + @Override + protected StagingTableCatalog destCatalog() { + return destCatalog; + } + + @Override + protected Identifier destTableIdent() { + return destTableIdent; + } + + @Override + public MigrateTableSparkAction tableProperties(Map properties) { + setProperties(properties); + return this; + } + + @Override + public MigrateTableSparkAction tableProperty(String property, String value) { + setProperty(property, value); + return this; + } + + @Override + public MigrateTableSparkAction dropBackup() { + this.dropBackup = true; + return this; + } + + @Override + public MigrateTableSparkAction backupTableName(String tableName) { + this.backupIdent = Identifier.of(sourceTableIdent().namespace(), tableName); + return this; + } + + @Override + public MigrateTableSparkAction executeWith(ExecutorService service) { + this.executorService = service; + return this; + } + + @Override + public MigrateTable.Result execute() { + String desc = String.format("Migrating table %s", destTableIdent().toString()); + JobGroupInfo info = newJobGroupInfo("MIGRATE-TABLE", desc); + return withJobGroupInfo(info, this::doExecute); + } + + private MigrateTable.Result doExecute() { + LOG.info("Starting the migration of {} to Iceberg", sourceTableIdent()); + + // move the source table to a new name, halting all modifications and allowing us to stage + // the creation of a new Iceberg table in its place + renameAndBackupSourceTable(); + + StagedSparkTable stagedTable = null; + Table icebergTable; + boolean threw = true; + try { + LOG.info("Staging a new Iceberg table {}", destTableIdent()); + stagedTable = stageDestTable(); + icebergTable = stagedTable.table(); + + LOG.info("Ensuring {} has a valid name mapping", destTableIdent()); + ensureNameMappingPresent(icebergTable); + + Some backupNamespace = Some.apply(backupIdent.namespace()[0]); + TableIdentifier v1BackupIdent = new TableIdentifier(backupIdent.name(), backupNamespace); + String stagingLocation = getMetadataLocation(icebergTable); + LOG.info("Generating Iceberg metadata for {} in {}", destTableIdent(), stagingLocation); + SparkTableUtil.importSparkTable( + spark(), v1BackupIdent, icebergTable, stagingLocation, executorService); + + LOG.info("Committing staged changes to {}", destTableIdent()); + stagedTable.commitStagedChanges(); + threw = false; + } finally { + if (threw) { + LOG.error( + "Failed to perform the migration, aborting table creation and restoring the original table"); + + restoreSourceTable(); + + if (stagedTable != null) { + try { + stagedTable.abortStagedChanges(); + } catch (Exception abortException) { + 
LOG.error("Cannot abort staged changes", abortException); + } + } + } else if (dropBackup) { + dropBackupTable(); + } + } + + Snapshot snapshot = icebergTable.currentSnapshot(); + long migratedDataFilesCount = + Long.parseLong(snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP)); + LOG.info( + "Successfully loaded Iceberg metadata for {} files to {}", + migratedDataFilesCount, + destTableIdent()); + return ImmutableMigrateTable.Result.builder() + .migratedDataFilesCount(migratedDataFilesCount) + .build(); + } + + @Override + protected Map destTableProps() { + Map properties = Maps.newHashMap(); + + // copy over relevant source table props + properties.putAll(JavaConverters.mapAsJavaMapConverter(v1SourceTable().properties()).asJava()); + EXCLUDED_PROPERTIES.forEach(properties::remove); + + // set default and user-provided props + properties.put(TableCatalog.PROP_PROVIDER, "iceberg"); + properties.putAll(additionalProperties()); + + // make sure we mark this table as migrated + properties.put("migrated", "true"); + + // inherit the source table location + properties.putIfAbsent(LOCATION, sourceTableLocation()); + + return properties; + } + + @Override + protected TableCatalog checkSourceCatalog(CatalogPlugin catalog) { + // currently the import code relies on being able to look up the table in the session catalog + Preconditions.checkArgument( + catalog instanceof SparkSessionCatalog, + "Cannot migrate a table from a non-Iceberg Spark Session Catalog. Found %s of class %s as the source catalog.", + catalog.name(), + catalog.getClass().getName()); + + return (TableCatalog) catalog; + } + + private void renameAndBackupSourceTable() { + try { + LOG.info("Renaming {} as {} for backup", sourceTableIdent(), backupIdent); + destCatalog().renameTable(sourceTableIdent(), backupIdent); + + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) { + throw new NoSuchTableException("Cannot find source table %s", sourceTableIdent()); + + } catch (org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException e) { + throw new AlreadyExistsException( + "Cannot rename %s as %s for backup. The backup table already exists.", + sourceTableIdent(), backupIdent); + } + } + + private void restoreSourceTable() { + try { + LOG.info("Restoring {} from {}", sourceTableIdent(), backupIdent); + destCatalog().renameTable(backupIdent, sourceTableIdent()); + + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) { + LOG.error( + "Cannot restore the original table, the backup table {} cannot be found", backupIdent, e); + + } catch (org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException e) { + LOG.error( + "Cannot restore the original table, a table with the original name exists. " + + "Use the backup table {} to restore the original table manually.", + backupIdent, + e); + } + } + + private void dropBackupTable() { + try { + destCatalog().dropTable(backupIdent); + } catch (Exception e) { + LOG.error( + "Cannot drop the backup table {}, after the migration is completed.", backupIdent, e); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/NDVSketchUtil.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/NDVSketchUtil.java new file mode 100644 index 000000000000..97fdc9102a37 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/NDVSketchUtil.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.nio.ByteBuffer; +import java.util.List; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.theta.CompactSketch; +import org.apache.datasketches.theta.Sketch; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.puffin.Blob; +import org.apache.iceberg.puffin.PuffinCompressionCodec; +import org.apache.iceberg.puffin.StandardBlobTypes; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.ExpressionColumnNode; +import org.apache.spark.sql.stats.ThetaSketchAgg; + +public class NDVSketchUtil { + + private NDVSketchUtil() {} + + public static final String APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY = "ndv"; + + static List generateBlobs( + SparkSession spark, Table table, Snapshot snapshot, List columns) { + Row sketches = computeNDVSketches(spark, table, snapshot, columns); + Schema schema = table.schemas().get(snapshot.schemaId()); + List blobs = Lists.newArrayList(); + for (int i = 0; i < columns.size(); i++) { + Types.NestedField field = schema.findField(columns.get(i)); + Sketch sketch = CompactSketch.wrap(Memory.wrap((byte[]) sketches.get(i))); + blobs.add(toBlob(field, sketch, snapshot)); + } + return blobs; + } + + private static Blob toBlob(Types.NestedField field, Sketch sketch, Snapshot snapshot) { + return new Blob( + StandardBlobTypes.APACHE_DATASKETCHES_THETA_V1, + ImmutableList.of(field.fieldId()), + snapshot.snapshotId(), + snapshot.sequenceNumber(), + ByteBuffer.wrap(sketch.toByteArray()), + PuffinCompressionCodec.ZSTD, + ImmutableMap.of( + APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY, + String.valueOf((long) sketch.getEstimate()))); + } + + private static Row computeNDVSketches( + SparkSession spark, Table table, Snapshot snapshot, List colNames) { + Dataset inputDF = SparkTableUtil.loadTable(spark, table, snapshot.snapshotId()); + return inputDF.select(toAggColumns(colNames)).first(); + } + + private static Column[] toAggColumns(List colNames) { + return colNames.stream().map(NDVSketchUtil::toAggColumn).toArray(Column[]::new); + } + + private static Column toAggColumn(String colName) { + ThetaSketchAgg agg = new ThetaSketchAgg(colName); + return new Column(ExpressionColumnNode.apply(agg.toAggregateExpression())); + } +} diff --git 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/RemoveDanglingDeletesSparkAction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/RemoveDanglingDeletesSparkAction.java new file mode 100644 index 000000000000..1474ec0e3eef --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/RemoveDanglingDeletesSparkAction.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.min; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.RewriteFiles; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.ImmutableRemoveDanglingDeleteFiles; +import org.apache.iceberg.actions.RemoveDanglingDeleteFiles; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.SparkDeleteFile; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An action that removes dangling delete files from the current snapshot. A delete file is dangling + * if its deletes no longer applies to any live data files. + * + *

+ * <p>The following dangling delete files are removed:
+ *
+ * <ul>
+ *   <li>Position delete files with a data sequence number less than that of any data file in the
+ *       same partition
+ *   <li>Equality delete files with a data sequence number less than or equal to that of any data
+ *       file in the same partition
+ * </ul>
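+ *
+ * <p>A minimal invocation sketch (illustrative only; within this module the action is normally
+ * reached through the data file rewrite path rather than constructed directly, since the class is
+ * package-private; assumes a {@code SparkSession} named {@code spark} and an existing Iceberg
+ * {@code Table} handle):
+ *
+ * <pre>{@code
+ * RewriteDataFiles.Result result =
+ *     SparkActions.get(spark)
+ *         .rewriteDataFiles(table)
+ *         .option(RewriteDataFiles.REMOVE_DANGLING_DELETES, "true")
+ *         .execute();
+ * }</pre>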
+ */ +class RemoveDanglingDeletesSparkAction + extends BaseSnapshotUpdateSparkAction + implements RemoveDanglingDeleteFiles { + + private static final Logger LOG = LoggerFactory.getLogger(RemoveDanglingDeletesSparkAction.class); + private final Table table; + + protected RemoveDanglingDeletesSparkAction(SparkSession spark, Table table) { + super(spark); + this.table = table; + } + + @Override + protected RemoveDanglingDeletesSparkAction self() { + return this; + } + + public Result execute() { + if (table.specs().size() == 1 && table.spec().isUnpartitioned()) { + // ManifestFilterManager already performs this table-wide delete on each commit + return ImmutableRemoveDanglingDeleteFiles.Result.builder() + .removedDeleteFiles(Collections.emptyList()) + .build(); + } + + String desc = String.format("Removing dangling delete files in %s", table.name()); + JobGroupInfo info = newJobGroupInfo("REMOVE-DELETES", desc); + return withJobGroupInfo(info, this::doExecute); + } + + Result doExecute() { + RewriteFiles rewriteFiles = table.newRewrite(); + List danglingDeletes = findDanglingDeletes(); + for (DeleteFile deleteFile : danglingDeletes) { + LOG.debug("Removing dangling delete file {}", deleteFile.location()); + rewriteFiles.deleteFile(deleteFile); + } + + if (!danglingDeletes.isEmpty()) { + commit(rewriteFiles); + } + + return ImmutableRemoveDanglingDeleteFiles.Result.builder() + .removedDeleteFiles(danglingDeletes) + .build(); + } + + /** + * Dangling delete files can be identified with following steps + * + *
    + *
+ * <ol>
+ *   <li>Group data files by partition keys and find the minimum data sequence number in each
+ *       group.
+ *   <li>Left outer join delete files with the partition-grouped data files on partition keys.
+ *   <li>Find dangling deletes by comparing each delete file's sequence number to its partition's
+ *       minimum data sequence number.
+ *   <li>Collect the result rows to the driver and use {@link SparkDeleteFile SparkDeleteFile} to
+ *       wrap them into valid delete files.
+ * </ol>
+ */ + private List findDanglingDeletes() { + Dataset minSequenceNumberByPartition = + loadMetadataTable(table, MetadataTableType.ENTRIES) + // find live data files + .filter("data_file.content == 0 AND status < 2") + .selectExpr( + "data_file.partition as partition", + "data_file.spec_id as spec_id", + "sequence_number") + .groupBy("partition", "spec_id") + .agg(min("sequence_number")) + .toDF("grouped_partition", "grouped_spec_id", "min_data_sequence_number"); + + Dataset deleteEntries = + loadMetadataTable(table, MetadataTableType.ENTRIES) + // find live delete files + .filter("data_file.content != 0 AND status < 2"); + + Column joinOnPartition = + deleteEntries + .col("data_file.spec_id") + .equalTo(minSequenceNumberByPartition.col("grouped_spec_id")) + .and( + deleteEntries + .col("data_file.partition") + .equalTo(minSequenceNumberByPartition.col("grouped_partition"))); + + Column filterOnDanglingDeletes = + col("min_data_sequence_number") + // delete fies without any data files in partition + .isNull() + // position delete files without any applicable data files in partition + .or( + col("data_file.content") + .equalTo("1") + .and(col("sequence_number").$less(col("min_data_sequence_number")))) + // equality delete files without any applicable data files in the partition + .or( + col("data_file.content") + .equalTo("2") + .and(col("sequence_number").$less$eq(col("min_data_sequence_number")))); + + Dataset danglingDeletes = + deleteEntries + .join(minSequenceNumberByPartition, joinOnPartition, "left") + .filter(filterOnDanglingDeletes) + .select("data_file.*"); + return danglingDeletes.collectAsList().stream() + // map on driver because SparkDeleteFile is not serializable + .map(row -> deleteFileWrapper(danglingDeletes.schema(), row)) + .collect(Collectors.toList()); + } + + private DeleteFile deleteFileWrapper(StructType sparkFileType, Row row) { + int specId = row.getInt(row.fieldIndex("spec_id")); + Types.StructType combinedFileType = DataFile.getType(Partitioning.partitionType(table)); + // Set correct spec id + Types.StructType projection = DataFile.getType(table.specs().get(specId).partitionType()); + return new SparkDeleteFile(combinedFileType, projection, sparkFileType).wrap(row); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteDataFilesSparkAction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteDataFilesSparkAction.java new file mode 100644 index 000000000000..4e381a7bd362 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteDataFilesSparkAction.java @@ -0,0 +1,545 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.io.IOException; +import java.math.RoundingMode; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.RewriteJobOrder; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.FileRewriter; +import org.apache.iceberg.actions.ImmutableRewriteDataFiles; +import org.apache.iceberg.actions.ImmutableRewriteDataFiles.Result.Builder; +import org.apache.iceberg.actions.RewriteDataFiles; +import org.apache.iceberg.actions.RewriteDataFilesCommitManager; +import org.apache.iceberg.actions.RewriteFileGroup; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.exceptions.CommitFailedException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Queues; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.relocated.com.google.common.math.IntMath; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.StructLikeMap; +import org.apache.iceberg.util.Tasks; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RewriteDataFilesSparkAction + extends BaseSnapshotUpdateSparkAction implements RewriteDataFiles { + + private static final Logger LOG = LoggerFactory.getLogger(RewriteDataFilesSparkAction.class); + private static final Set VALID_OPTIONS = + ImmutableSet.of( + MAX_CONCURRENT_FILE_GROUP_REWRITES, + MAX_FILE_GROUP_SIZE_BYTES, + PARTIAL_PROGRESS_ENABLED, + PARTIAL_PROGRESS_MAX_COMMITS, + PARTIAL_PROGRESS_MAX_FAILED_COMMITS, + TARGET_FILE_SIZE_BYTES, + USE_STARTING_SEQUENCE_NUMBER, + REWRITE_JOB_ORDER, + OUTPUT_SPEC_ID, + REMOVE_DANGLING_DELETES); + + private static final RewriteDataFilesSparkAction.Result EMPTY_RESULT = + ImmutableRewriteDataFiles.Result.builder().rewriteResults(ImmutableList.of()).build(); + + private final Table table; + + private Expression filter = Expressions.alwaysTrue(); + private int 
maxConcurrentFileGroupRewrites; + private int maxCommits; + private int maxFailedCommits; + private boolean partialProgressEnabled; + private boolean removeDanglingDeletes; + private boolean useStartingSequenceNumber; + private RewriteJobOrder rewriteJobOrder; + private FileRewriter rewriter = null; + + RewriteDataFilesSparkAction(SparkSession spark, Table table) { + super(spark.cloneSession()); + // Disable Adaptive Query Execution as this may change the output partitioning of our write + spark().conf().set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), false); + this.table = table; + } + + @Override + protected RewriteDataFilesSparkAction self() { + return this; + } + + @Override + public RewriteDataFilesSparkAction binPack() { + Preconditions.checkArgument( + rewriter == null, "Must use only one rewriter type (bin-pack, sort, zorder)"); + this.rewriter = new SparkBinPackDataRewriter(spark(), table); + return this; + } + + @Override + public RewriteDataFilesSparkAction sort(SortOrder sortOrder) { + Preconditions.checkArgument( + rewriter == null, "Must use only one rewriter type (bin-pack, sort, zorder)"); + this.rewriter = new SparkSortDataRewriter(spark(), table, sortOrder); + return this; + } + + @Override + public RewriteDataFilesSparkAction sort() { + Preconditions.checkArgument( + rewriter == null, "Must use only one rewriter type (bin-pack, sort, zorder)"); + this.rewriter = new SparkSortDataRewriter(spark(), table); + return this; + } + + @Override + public RewriteDataFilesSparkAction zOrder(String... columnNames) { + Preconditions.checkArgument( + rewriter == null, "Must use only one rewriter type (bin-pack, sort, zorder)"); + this.rewriter = new SparkZOrderDataRewriter(spark(), table, Arrays.asList(columnNames)); + return this; + } + + @Override + public RewriteDataFilesSparkAction filter(Expression expression) { + filter = Expressions.and(filter, expression); + return this; + } + + @Override + public RewriteDataFiles.Result execute() { + if (table.currentSnapshot() == null) { + return EMPTY_RESULT; + } + + long startingSnapshotId = table.currentSnapshot().snapshotId(); + + // Default to BinPack if no strategy selected + if (this.rewriter == null) { + this.rewriter = new SparkBinPackDataRewriter(spark(), table); + } + + validateAndInitOptions(); + + StructLikeMap>> fileGroupsByPartition = + planFileGroups(startingSnapshotId); + RewriteExecutionContext ctx = new RewriteExecutionContext(fileGroupsByPartition); + + if (ctx.totalGroupCount() == 0) { + LOG.info("Nothing found to rewrite in {}", table.name()); + return EMPTY_RESULT; + } + + Stream groupStream = toGroupStream(ctx, fileGroupsByPartition); + + Builder resultBuilder = + partialProgressEnabled + ? 
doExecuteWithPartialProgress(ctx, groupStream, commitManager(startingSnapshotId)) + : doExecute(ctx, groupStream, commitManager(startingSnapshotId)); + + if (removeDanglingDeletes) { + RemoveDanglingDeletesSparkAction action = + new RemoveDanglingDeletesSparkAction(spark(), table); + int removedCount = Iterables.size(action.execute().removedDeleteFiles()); + resultBuilder.removedDeleteFilesCount(removedCount); + } + return resultBuilder.build(); + } + + StructLikeMap>> planFileGroups(long startingSnapshotId) { + CloseableIterable fileScanTasks = + table + .newScan() + .useSnapshot(startingSnapshotId) + .filter(filter) + .ignoreResiduals() + .planFiles(); + + try { + StructType partitionType = table.spec().partitionType(); + StructLikeMap> filesByPartition = + groupByPartition(partitionType, fileScanTasks); + return fileGroupsByPartition(filesByPartition); + } finally { + try { + fileScanTasks.close(); + } catch (IOException io) { + LOG.error("Cannot properly close file iterable while planning for rewrite", io); + } + } + } + + private StructLikeMap> groupByPartition( + StructType partitionType, Iterable tasks) { + StructLikeMap> filesByPartition = StructLikeMap.create(partitionType); + StructLike emptyStruct = GenericRecord.create(partitionType); + + for (FileScanTask task : tasks) { + // If a task uses an incompatible partition spec the data inside could contain values + // which belong to multiple partitions in the current spec. Treating all such files as + // un-partitioned and grouping them together helps to minimize new files made. + StructLike taskPartition = + task.file().specId() == table.spec().specId() ? task.file().partition() : emptyStruct; + + List files = filesByPartition.get(taskPartition); + if (files == null) { + files = Lists.newArrayList(); + } + + files.add(task); + filesByPartition.put(taskPartition, files); + } + return filesByPartition; + } + + private StructLikeMap>> fileGroupsByPartition( + StructLikeMap> filesByPartition) { + return filesByPartition.transformValues(this::planFileGroups); + } + + private List> planFileGroups(List tasks) { + return ImmutableList.copyOf(rewriter.planFileGroups(tasks)); + } + + @VisibleForTesting + RewriteFileGroup rewriteFiles(RewriteExecutionContext ctx, RewriteFileGroup fileGroup) { + String desc = jobDesc(fileGroup, ctx); + Set addedFiles = + withJobGroupInfo( + newJobGroupInfo("REWRITE-DATA-FILES", desc), + () -> rewriter.rewrite(fileGroup.fileScans())); + + fileGroup.setOutputFiles(addedFiles); + LOG.info("Rewrite Files Ready to be Committed - {}", desc); + return fileGroup; + } + + private ExecutorService rewriteService() { + return MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) + Executors.newFixedThreadPool( + maxConcurrentFileGroupRewrites, + new ThreadFactoryBuilder().setNameFormat("Rewrite-Service-%d").build())); + } + + @VisibleForTesting + RewriteDataFilesCommitManager commitManager(long startingSnapshotId) { + return new RewriteDataFilesCommitManager( + table, startingSnapshotId, useStartingSequenceNumber, commitSummary()); + } + + private Builder doExecute( + RewriteExecutionContext ctx, + Stream groupStream, + RewriteDataFilesCommitManager commitManager) { + ExecutorService rewriteService = rewriteService(); + + ConcurrentLinkedQueue rewrittenGroups = Queues.newConcurrentLinkedQueue(); + + Tasks.Builder rewriteTaskBuilder = + Tasks.foreach(groupStream) + .executeWith(rewriteService) + .stopOnFailure() + .noRetry() + .onFailure( + (fileGroup, exception) -> { + LOG.warn( + "Failure during rewrite 
process for group {}", fileGroup.info(), exception); + }); + + try { + rewriteTaskBuilder.run( + fileGroup -> { + rewrittenGroups.add(rewriteFiles(ctx, fileGroup)); + }); + } catch (Exception e) { + // At least one rewrite group failed, clean up all completed rewrites + LOG.error( + "Cannot complete rewrite, {} is not enabled and one of the file set groups failed to " + + "be rewritten. This error occurred during the writing of new files, not during the commit process. This " + + "indicates something is wrong that doesn't involve conflicts with other Iceberg operations. Enabling " + + "{} may help in this case but the root cause should be investigated. Cleaning up {} groups which finished " + + "being written.", + PARTIAL_PROGRESS_ENABLED, + PARTIAL_PROGRESS_ENABLED, + rewrittenGroups.size(), + e); + + Tasks.foreach(rewrittenGroups) + .suppressFailureWhenFinished() + .run(commitManager::abortFileGroup); + throw e; + } finally { + rewriteService.shutdown(); + } + + try { + commitManager.commitOrClean(Sets.newHashSet(rewrittenGroups)); + } catch (ValidationException | CommitFailedException e) { + String errorMessage = + String.format( + "Cannot commit rewrite because of a ValidationException or CommitFailedException. This usually means that " + + "this rewrite has conflicted with another concurrent Iceberg operation. To reduce the likelihood of " + + "conflicts, set %s which will break up the rewrite into multiple smaller commits controlled by %s. " + + "Separate smaller rewrite commits can succeed independently while any commits that conflict with " + + "another Iceberg operation will be ignored. This mode will create additional snapshots in the table " + + "history, one for each commit.", + PARTIAL_PROGRESS_ENABLED, PARTIAL_PROGRESS_MAX_COMMITS); + throw new RuntimeException(errorMessage, e); + } + + List rewriteResults = + rewrittenGroups.stream().map(RewriteFileGroup::asResult).collect(Collectors.toList()); + return ImmutableRewriteDataFiles.Result.builder().rewriteResults(rewriteResults); + } + + private Builder doExecuteWithPartialProgress( + RewriteExecutionContext ctx, + Stream groupStream, + RewriteDataFilesCommitManager commitManager) { + ExecutorService rewriteService = rewriteService(); + + // start commit service + int groupsPerCommit = IntMath.divide(ctx.totalGroupCount(), maxCommits, RoundingMode.CEILING); + RewriteDataFilesCommitManager.CommitService commitService = + commitManager.service(groupsPerCommit); + commitService.start(); + + Collection rewriteFailures = new ConcurrentLinkedQueue<>(); + // start rewrite tasks + Tasks.foreach(groupStream) + .suppressFailureWhenFinished() + .executeWith(rewriteService) + .noRetry() + .onFailure( + (fileGroup, exception) -> { + LOG.error("Failure during rewrite group {}", fileGroup.info(), exception); + rewriteFailures.add( + ImmutableRewriteDataFiles.FileGroupFailureResult.builder() + .info(fileGroup.info()) + .dataFilesCount(fileGroup.numFiles()) + .build()); + }) + .run(fileGroup -> commitService.offer(rewriteFiles(ctx, fileGroup))); + rewriteService.shutdown(); + + // stop commit service + commitService.close(); + + int failedCommits = maxCommits - commitService.succeededCommits(); + if (failedCommits > 0 && failedCommits <= maxFailedCommits) { + LOG.warn( + "{} is true but {} rewrite commits failed. Check the logs to determine why the individual " + + "commits failed. 
If this is persistent it may help to increase {} which will split the rewrite operation " + + "into smaller commits.", + PARTIAL_PROGRESS_ENABLED, + failedCommits, + PARTIAL_PROGRESS_MAX_COMMITS); + } else if (failedCommits > maxFailedCommits) { + String errorMessage = + String.format( + "%s is true but %d rewrite commits failed. This is more than the maximum allowed failures of %d. " + + "Check the logs to determine why the individual commits failed. If this is persistent it may help to " + + "increase %s which will split the rewrite operation into smaller commits.", + PARTIAL_PROGRESS_ENABLED, + failedCommits, + maxFailedCommits, + PARTIAL_PROGRESS_MAX_COMMITS); + throw new RuntimeException(errorMessage); + } + + return ImmutableRewriteDataFiles.Result.builder() + .rewriteResults(toRewriteResults(commitService.results())) + .rewriteFailures(rewriteFailures); + } + + Stream toGroupStream( + RewriteExecutionContext ctx, Map>> groupsByPartition) { + return groupsByPartition.entrySet().stream() + .filter(e -> !e.getValue().isEmpty()) + .flatMap( + e -> { + StructLike partition = e.getKey(); + List> scanGroups = e.getValue(); + return scanGroups.stream().map(tasks -> newRewriteGroup(ctx, partition, tasks)); + }) + .sorted(RewriteFileGroup.comparator(rewriteJobOrder)); + } + + private RewriteFileGroup newRewriteGroup( + RewriteExecutionContext ctx, StructLike partition, List tasks) { + int globalIndex = ctx.currentGlobalIndex(); + int partitionIndex = ctx.currentPartitionIndex(partition); + FileGroupInfo info = + ImmutableRewriteDataFiles.FileGroupInfo.builder() + .globalIndex(globalIndex) + .partitionIndex(partitionIndex) + .partition(partition) + .build(); + return new RewriteFileGroup(info, tasks); + } + + private Iterable toRewriteResults(List commitResults) { + return commitResults.stream().map(RewriteFileGroup::asResult).collect(Collectors.toList()); + } + + void validateAndInitOptions() { + Set validOptions = Sets.newHashSet(rewriter.validOptions()); + validOptions.addAll(VALID_OPTIONS); + + Set invalidKeys = Sets.newHashSet(options().keySet()); + invalidKeys.removeAll(validOptions); + + Preconditions.checkArgument( + invalidKeys.isEmpty(), + "Cannot use options %s, they are not supported by the action or the rewriter %s", + invalidKeys, + rewriter.description()); + + rewriter.init(options()); + + maxConcurrentFileGroupRewrites = + PropertyUtil.propertyAsInt( + options(), + MAX_CONCURRENT_FILE_GROUP_REWRITES, + MAX_CONCURRENT_FILE_GROUP_REWRITES_DEFAULT); + + maxCommits = + PropertyUtil.propertyAsInt( + options(), PARTIAL_PROGRESS_MAX_COMMITS, PARTIAL_PROGRESS_MAX_COMMITS_DEFAULT); + + maxFailedCommits = + PropertyUtil.propertyAsInt(options(), PARTIAL_PROGRESS_MAX_FAILED_COMMITS, maxCommits); + + partialProgressEnabled = + PropertyUtil.propertyAsBoolean( + options(), PARTIAL_PROGRESS_ENABLED, PARTIAL_PROGRESS_ENABLED_DEFAULT); + + useStartingSequenceNumber = + PropertyUtil.propertyAsBoolean( + options(), USE_STARTING_SEQUENCE_NUMBER, USE_STARTING_SEQUENCE_NUMBER_DEFAULT); + + removeDanglingDeletes = + PropertyUtil.propertyAsBoolean( + options(), REMOVE_DANGLING_DELETES, REMOVE_DANGLING_DELETES_DEFAULT); + + rewriteJobOrder = + RewriteJobOrder.fromName( + PropertyUtil.propertyAsString(options(), REWRITE_JOB_ORDER, REWRITE_JOB_ORDER_DEFAULT)); + + Preconditions.checkArgument( + maxConcurrentFileGroupRewrites >= 1, + "Cannot set %s to %s, the value must be positive.", + MAX_CONCURRENT_FILE_GROUP_REWRITES, + maxConcurrentFileGroupRewrites); + + Preconditions.checkArgument( + 
!partialProgressEnabled || maxCommits > 0, + "Cannot set %s to %s, the value must be positive when %s is true", + PARTIAL_PROGRESS_MAX_COMMITS, + maxCommits, + PARTIAL_PROGRESS_ENABLED); + } + + private String jobDesc(RewriteFileGroup group, RewriteExecutionContext ctx) { + StructLike partition = group.info().partition(); + if (partition.size() > 0) { + return String.format( + "Rewriting %d files (%s, file group %d/%d, %s (%d/%d)) in %s", + group.rewrittenFiles().size(), + rewriter.description(), + group.info().globalIndex(), + ctx.totalGroupCount(), + partition, + group.info().partitionIndex(), + ctx.groupsInPartition(partition), + table.name()); + } else { + return String.format( + "Rewriting %d files (%s, file group %d/%d) in %s", + group.rewrittenFiles().size(), + rewriter.description(), + group.info().globalIndex(), + ctx.totalGroupCount(), + table.name()); + } + } + + @VisibleForTesting + static class RewriteExecutionContext { + private final StructLikeMap numGroupsByPartition; + private final int totalGroupCount; + private final Map partitionIndexMap; + private final AtomicInteger groupIndex; + + RewriteExecutionContext(StructLikeMap>> fileGroupsByPartition) { + this.numGroupsByPartition = fileGroupsByPartition.transformValues(List::size); + this.totalGroupCount = numGroupsByPartition.values().stream().reduce(Integer::sum).orElse(0); + this.partitionIndexMap = Maps.newConcurrentMap(); + this.groupIndex = new AtomicInteger(1); + } + + public int currentGlobalIndex() { + return groupIndex.getAndIncrement(); + } + + public int currentPartitionIndex(StructLike partition) { + return partitionIndexMap.merge(partition, 1, Integer::sum); + } + + public int groupsInPartition(StructLike partition) { + return numGroupsByPartition.get(partition); + } + + public int totalGroupCount() { + return totalGroupCount; + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteManifestsSparkAction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteManifestsSparkAction.java new file mode 100644 index 000000000000..60e2b11881cb --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteManifestsSparkAction.java @@ -0,0 +1,553 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.MetadataTableType.ENTRIES; + +import java.io.Serializable; +import java.util.Iterator; +import java.util.List; +import java.util.UUID; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.ManifestContent; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestWriter; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.RollingManifestWriter; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.ImmutableRewriteManifests; +import org.apache.iceberg.actions.RewriteManifests; +import org.apache.iceberg.exceptions.CleanableFailure; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.SparkContentFile; +import org.apache.iceberg.spark.SparkDataFile; +import org.apache.iceberg.spark.SparkDeleteFile; +import org.apache.iceberg.spark.source.SerializableTableWithSize; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An action that rewrites manifests in a distributed manner and co-locates metadata for partitions. + * + *

By default, this action rewrites all manifests for the current partition spec and writes the + * result to the metadata folder. The behavior can be modified by passing a custom predicate to + * {@link #rewriteIf(Predicate)} and a custom spec ID to {@link #specId(int)}. In addition, there is + * a way to configure a custom location for staged manifests via {@link #stagingLocation(String)}. + * The provided staging location will be ignored if snapshot ID inheritance is enabled. In such + * cases, the manifests are always written to the metadata folder and committed without staging. + */ +public class RewriteManifestsSparkAction + extends BaseSnapshotUpdateSparkAction implements RewriteManifests { + + public static final String USE_CACHING = "use-caching"; + public static final boolean USE_CACHING_DEFAULT = false; + + private static final Logger LOG = LoggerFactory.getLogger(RewriteManifestsSparkAction.class); + private static final RewriteManifests.Result EMPTY_RESULT = + ImmutableRewriteManifests.Result.builder() + .rewrittenManifests(ImmutableList.of()) + .addedManifests(ImmutableList.of()) + .build(); + + private final Table table; + private final int formatVersion; + private final long targetManifestSizeBytes; + private final boolean shouldStageManifests; + + private PartitionSpec spec; + private Predicate predicate = manifest -> true; + private String outputLocation; + + RewriteManifestsSparkAction(SparkSession spark, Table table) { + super(spark); + this.table = table; + this.spec = table.spec(); + this.targetManifestSizeBytes = + PropertyUtil.propertyAsLong( + table.properties(), + TableProperties.MANIFEST_TARGET_SIZE_BYTES, + TableProperties.MANIFEST_TARGET_SIZE_BYTES_DEFAULT); + + // default the output location to the metadata location + TableOperations ops = ((HasTableOperations) table).operations(); + Path metadataFilePath = new Path(ops.metadataFileLocation("file")); + this.outputLocation = metadataFilePath.getParent().toString(); + + // use the current table format version for new manifests + this.formatVersion = ops.current().formatVersion(); + + boolean snapshotIdInheritanceEnabled = + PropertyUtil.propertyAsBoolean( + table.properties(), + TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, + TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED_DEFAULT); + this.shouldStageManifests = formatVersion == 1 && !snapshotIdInheritanceEnabled; + } + + @Override + protected RewriteManifestsSparkAction self() { + return this; + } + + @Override + public RewriteManifestsSparkAction specId(int specId) { + Preconditions.checkArgument(table.specs().containsKey(specId), "Invalid spec id %s", specId); + this.spec = table.specs().get(specId); + return this; + } + + @Override + public RewriteManifestsSparkAction rewriteIf(Predicate newPredicate) { + this.predicate = newPredicate; + return this; + } + + @Override + public RewriteManifestsSparkAction stagingLocation(String newStagingLocation) { + if (shouldStageManifests) { + this.outputLocation = newStagingLocation; + } else { + LOG.warn("Ignoring provided staging location as new manifests will be committed directly"); + } + return this; + } + + @Override + public RewriteManifests.Result execute() { + String desc = String.format("Rewriting manifests in %s", table.name()); + JobGroupInfo info = newJobGroupInfo("REWRITE-MANIFESTS", desc); + return withJobGroupInfo(info, this::doExecute); + } + + private RewriteManifests.Result doExecute() { + List rewrittenManifests = Lists.newArrayList(); + List addedManifests = Lists.newArrayList(); + + 
RewriteManifests.Result dataResult = rewriteManifests(ManifestContent.DATA); + Iterables.addAll(rewrittenManifests, dataResult.rewrittenManifests()); + Iterables.addAll(addedManifests, dataResult.addedManifests()); + + RewriteManifests.Result deletesResult = rewriteManifests(ManifestContent.DELETES); + Iterables.addAll(rewrittenManifests, deletesResult.rewrittenManifests()); + Iterables.addAll(addedManifests, deletesResult.addedManifests()); + + if (rewrittenManifests.isEmpty()) { + return EMPTY_RESULT; + } + + replaceManifests(rewrittenManifests, addedManifests); + + return ImmutableRewriteManifests.Result.builder() + .rewrittenManifests(rewrittenManifests) + .addedManifests(addedManifests) + .build(); + } + + private RewriteManifests.Result rewriteManifests(ManifestContent content) { + List matchingManifests = findMatchingManifests(content); + if (matchingManifests.isEmpty()) { + return EMPTY_RESULT; + } + + int targetNumManifests = targetNumManifests(totalSizeBytes(matchingManifests)); + if (targetNumManifests == 1 && matchingManifests.size() == 1) { + return EMPTY_RESULT; + } + + Dataset manifestEntryDF = buildManifestEntryDF(matchingManifests); + + List newManifests; + if (spec.isUnpartitioned()) { + newManifests = writeUnpartitionedManifests(content, manifestEntryDF, targetNumManifests); + } else { + newManifests = writePartitionedManifests(content, manifestEntryDF, targetNumManifests); + } + + return ImmutableRewriteManifests.Result.builder() + .rewrittenManifests(matchingManifests) + .addedManifests(newManifests) + .build(); + } + + private Dataset buildManifestEntryDF(List manifests) { + Dataset manifestDF = + spark() + .createDataset(Lists.transform(manifests, ManifestFile::path), Encoders.STRING()) + .toDF("manifest"); + + Dataset manifestEntryDF = + loadMetadataTable(table, ENTRIES) + .filter("status < 2") // select only live entries + .selectExpr( + "input_file_name() as manifest", + "snapshot_id", + "sequence_number", + "file_sequence_number", + "data_file"); + + Column joinCond = manifestDF.col("manifest").equalTo(manifestEntryDF.col("manifest")); + return manifestEntryDF + .join(manifestDF, joinCond, "left_semi") + .select("snapshot_id", "sequence_number", "file_sequence_number", "data_file"); + } + + private List writeUnpartitionedManifests( + ManifestContent content, Dataset manifestEntryDF, int numManifests) { + + WriteManifests writeFunc = newWriteManifestsFunc(content, manifestEntryDF.schema()); + Dataset transformedManifestEntryDF = manifestEntryDF.repartition(numManifests); + return writeFunc.apply(transformedManifestEntryDF).collectAsList(); + } + + private List writePartitionedManifests( + ManifestContent content, Dataset manifestEntryDF, int numManifests) { + + return withReusableDS( + manifestEntryDF, + df -> { + WriteManifests writeFunc = newWriteManifestsFunc(content, df.schema()); + Column partitionColumn = df.col("data_file.partition"); + Dataset transformedDF = repartitionAndSort(df, partitionColumn, numManifests); + return writeFunc.apply(transformedDF).collectAsList(); + }); + } + + private WriteManifests newWriteManifestsFunc(ManifestContent content, StructType sparkType) { + ManifestWriterFactory writers = manifestWriters(); + + StructType sparkFileType = (StructType) sparkType.apply("data_file").dataType(); + Types.StructType combinedFileType = DataFile.getType(Partitioning.partitionType(table)); + Types.StructType fileType = DataFile.getType(spec.partitionType()); + + if (content == ManifestContent.DATA) { + return new WriteDataManifests(writers, 
combinedFileType, fileType, sparkFileType); + } else { + return new WriteDeleteManifests(writers, combinedFileType, fileType, sparkFileType); + } + } + + private Dataset repartitionAndSort(Dataset df, Column col, int numPartitions) { + return df.repartitionByRange(numPartitions, col).sortWithinPartitions(col); + } + + private U withReusableDS(Dataset ds, Function, U> func) { + boolean useCaching = + PropertyUtil.propertyAsBoolean(options(), USE_CACHING, USE_CACHING_DEFAULT); + Dataset reusableDS = useCaching ? ds.cache() : ds; + + try { + return func.apply(reusableDS); + } finally { + if (useCaching) { + reusableDS.unpersist(false); + } + } + } + + private List findMatchingManifests(ManifestContent content) { + Snapshot currentSnapshot = table.currentSnapshot(); + + if (currentSnapshot == null) { + return ImmutableList.of(); + } + + List manifests = loadManifests(content, currentSnapshot); + + return manifests.stream() + .filter(manifest -> manifest.partitionSpecId() == spec.specId() && predicate.test(manifest)) + .collect(Collectors.toList()); + } + + private List loadManifests(ManifestContent content, Snapshot snapshot) { + switch (content) { + case DATA: + return snapshot.dataManifests(table.io()); + case DELETES: + return snapshot.deleteManifests(table.io()); + default: + throw new IllegalArgumentException("Unknown manifest content: " + content); + } + } + + private int targetNumManifests(long totalSizeBytes) { + return (int) ((totalSizeBytes + targetManifestSizeBytes - 1) / targetManifestSizeBytes); + } + + private long totalSizeBytes(Iterable manifests) { + long totalSizeBytes = 0L; + + for (ManifestFile manifest : manifests) { + ValidationException.check( + hasFileCounts(manifest), "No file counts in manifest: %s", manifest.path()); + totalSizeBytes += manifest.length(); + } + + return totalSizeBytes; + } + + private boolean hasFileCounts(ManifestFile manifest) { + return manifest.addedFilesCount() != null + && manifest.existingFilesCount() != null + && manifest.deletedFilesCount() != null; + } + + private void replaceManifests( + Iterable deletedManifests, Iterable addedManifests) { + try { + org.apache.iceberg.RewriteManifests rewriteManifests = table.rewriteManifests(); + deletedManifests.forEach(rewriteManifests::deleteManifest); + addedManifests.forEach(rewriteManifests::addManifest); + commit(rewriteManifests); + + if (shouldStageManifests) { + // delete new manifests as they were rewritten before the commit + deleteFiles(Iterables.transform(addedManifests, ManifestFile::path)); + } + } catch (CommitStateUnknownException commitStateUnknownException) { + // don't clean up added manifest files, because they may have been successfully committed. 
+ throw commitStateUnknownException; + } catch (Exception e) { + if (e instanceof CleanableFailure) { + // delete all new manifests because the rewrite failed + deleteFiles(Iterables.transform(addedManifests, ManifestFile::path)); + } + + throw e; + } + } + + private void deleteFiles(Iterable locations) { + Iterable files = + Iterables.transform(locations, location -> new FileInfo(location, MANIFEST)); + if (table.io() instanceof SupportsBulkOperations) { + deleteFiles((SupportsBulkOperations) table.io(), files.iterator()); + } else { + deleteFiles( + ThreadPools.getWorkerPool(), file -> table.io().deleteFile(file), files.iterator()); + } + } + + private ManifestWriterFactory manifestWriters() { + return new ManifestWriterFactory( + sparkContext().broadcast(SerializableTableWithSize.copyOf(table)), + formatVersion, + spec.specId(), + outputLocation, + // allow the actual size of manifests to be 20% higher as the estimation is not precise + (long) (1.2 * targetManifestSizeBytes)); + } + + private static class WriteDataManifests extends WriteManifests { + + WriteDataManifests( + ManifestWriterFactory manifestWriters, + Types.StructType combinedPartitionType, + Types.StructType partitionType, + StructType sparkFileType) { + super(manifestWriters, combinedPartitionType, partitionType, sparkFileType); + } + + @Override + protected SparkDataFile newFileWrapper() { + return new SparkDataFile(combinedFileType(), fileType(), sparkFileType()); + } + + @Override + protected RollingManifestWriter newManifestWriter() { + return writers().newRollingManifestWriter(); + } + } + + private static class WriteDeleteManifests extends WriteManifests { + + WriteDeleteManifests( + ManifestWriterFactory manifestWriters, + Types.StructType combinedFileType, + Types.StructType fileType, + StructType sparkFileType) { + super(manifestWriters, combinedFileType, fileType, sparkFileType); + } + + @Override + protected SparkDeleteFile newFileWrapper() { + return new SparkDeleteFile(combinedFileType(), fileType(), sparkFileType()); + } + + @Override + protected RollingManifestWriter newManifestWriter() { + return writers().newRollingDeleteManifestWriter(); + } + } + + private abstract static class WriteManifests> + implements MapPartitionsFunction { + + private static final Encoder MANIFEST_ENCODER = + Encoders.javaSerialization(ManifestFile.class); + + private final ManifestWriterFactory writers; + private final Types.StructType combinedFileType; + private final Types.StructType fileType; + private final StructType sparkFileType; + + WriteManifests( + ManifestWriterFactory writers, + Types.StructType combinedFileType, + Types.StructType fileType, + StructType sparkFileType) { + this.writers = writers; + this.combinedFileType = combinedFileType; + this.fileType = fileType; + this.sparkFileType = sparkFileType; + } + + protected abstract SparkContentFile newFileWrapper(); + + protected abstract RollingManifestWriter newManifestWriter(); + + public Dataset apply(Dataset input) { + return input.mapPartitions(this, MANIFEST_ENCODER); + } + + @Override + public Iterator call(Iterator rows) throws Exception { + SparkContentFile fileWrapper = newFileWrapper(); + RollingManifestWriter writer = newManifestWriter(); + + try { + while (rows.hasNext()) { + Row row = rows.next(); + long snapshotId = row.getLong(0); + long sequenceNumber = row.getLong(1); + Long fileSequenceNumber = row.isNullAt(2) ? 
null : row.getLong(2); + Row file = row.getStruct(3); + writer.existing(fileWrapper.wrap(file), snapshotId, sequenceNumber, fileSequenceNumber); + } + } finally { + writer.close(); + } + + return writer.toManifestFiles().iterator(); + } + + protected ManifestWriterFactory writers() { + return writers; + } + + protected Types.StructType combinedFileType() { + return combinedFileType; + } + + protected Types.StructType fileType() { + return fileType; + } + + protected StructType sparkFileType() { + return sparkFileType; + } + } + + private static class ManifestWriterFactory implements Serializable { + private final Broadcast
<Table> tableBroadcast; + private final int formatVersion; + private final int specId; + private final String outputLocation; + private final long maxManifestSizeBytes; + + ManifestWriterFactory( + Broadcast<Table>
tableBroadcast, + int formatVersion, + int specId, + String outputLocation, + long maxManifestSizeBytes) { + this.tableBroadcast = tableBroadcast; + this.formatVersion = formatVersion; + this.specId = specId; + this.outputLocation = outputLocation; + this.maxManifestSizeBytes = maxManifestSizeBytes; + } + + public RollingManifestWriter newRollingManifestWriter() { + return new RollingManifestWriter<>(this::newManifestWriter, maxManifestSizeBytes); + } + + private ManifestWriter newManifestWriter() { + return ManifestFiles.write(formatVersion, spec(), newOutputFile(), null); + } + + public RollingManifestWriter newRollingDeleteManifestWriter() { + return new RollingManifestWriter<>(this::newDeleteManifestWriter, maxManifestSizeBytes); + } + + private ManifestWriter newDeleteManifestWriter() { + return ManifestFiles.writeDeleteManifest(formatVersion, spec(), newOutputFile(), null); + } + + private PartitionSpec spec() { + return table().specs().get(specId); + } + + private OutputFile newOutputFile() { + return table().io().newOutputFile(newManifestLocation()); + } + + private String newManifestLocation() { + String fileName = FileFormat.AVRO.addExtension("optimized-m-" + UUID.randomUUID()); + Path filePath = new Path(outputLocation, fileName); + return filePath.toString(); + } + + private Table table() { + return tableBroadcast.value(); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/RewritePositionDeleteFilesSparkAction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/RewritePositionDeleteFilesSparkAction.java new file mode 100644 index 000000000000..282222ae716f --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/RewritePositionDeleteFilesSparkAction.java @@ -0,0 +1,462 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.io.IOException; +import java.math.RoundingMode; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.MetadataTableUtils; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.PositionDeletesScanTask; +import org.apache.iceberg.PositionDeletesTable.PositionDeletesBatchScan; +import org.apache.iceberg.RewriteJobOrder; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.ImmutableRewritePositionDeleteFiles; +import org.apache.iceberg.actions.RewritePositionDeleteFiles; +import org.apache.iceberg.actions.RewritePositionDeletesCommitManager; +import org.apache.iceberg.actions.RewritePositionDeletesCommitManager.CommitService; +import org.apache.iceberg.actions.RewritePositionDeletesGroup; +import org.apache.iceberg.exceptions.CommitFailedException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Queues; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.relocated.com.google.common.math.IntMath; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.PartitionUtil; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.StructLikeMap; +import org.apache.iceberg.util.Tasks; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Spark implementation of {@link RewritePositionDeleteFiles}. 
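+ *
+ * <p>A usage sketch (illustrative only, not part of the original change):
+ *
+ * <pre>
+ *   SparkActions.get(spark)
+ *       .rewritePositionDeletes(table)
+ *       .option(RewritePositionDeleteFiles.PARTIAL_PROGRESS_ENABLED, "true")
+ *       .execute();
+ * </pre>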
*/ +public class RewritePositionDeleteFilesSparkAction + extends BaseSnapshotUpdateSparkAction + implements RewritePositionDeleteFiles { + + private static final Logger LOG = + LoggerFactory.getLogger(RewritePositionDeleteFilesSparkAction.class); + private static final Set VALID_OPTIONS = + ImmutableSet.of( + MAX_CONCURRENT_FILE_GROUP_REWRITES, + PARTIAL_PROGRESS_ENABLED, + PARTIAL_PROGRESS_MAX_COMMITS, + REWRITE_JOB_ORDER); + private static final Result EMPTY_RESULT = + ImmutableRewritePositionDeleteFiles.Result.builder().build(); + + private final Table table; + private final SparkBinPackPositionDeletesRewriter rewriter; + private Expression filter = Expressions.alwaysTrue(); + + private int maxConcurrentFileGroupRewrites; + private int maxCommits; + private boolean partialProgressEnabled; + private RewriteJobOrder rewriteJobOrder; + + RewritePositionDeleteFilesSparkAction(SparkSession spark, Table table) { + super(spark); + this.table = table; + this.rewriter = new SparkBinPackPositionDeletesRewriter(spark(), table); + } + + @Override + protected RewritePositionDeleteFilesSparkAction self() { + return this; + } + + @Override + public RewritePositionDeleteFilesSparkAction filter(Expression expression) { + filter = Expressions.and(filter, expression); + return this; + } + + @Override + public RewritePositionDeleteFiles.Result execute() { + if (table.currentSnapshot() == null) { + LOG.info("Nothing found to rewrite in empty table {}", table.name()); + return EMPTY_RESULT; + } + + validateAndInitOptions(); + + StructLikeMap>> fileGroupsByPartition = planFileGroups(); + RewriteExecutionContext ctx = new RewriteExecutionContext(fileGroupsByPartition); + + if (ctx.totalGroupCount() == 0) { + LOG.info("Nothing found to rewrite in {}", table.name()); + return EMPTY_RESULT; + } + + Stream groupStream = toGroupStream(ctx, fileGroupsByPartition); + + if (partialProgressEnabled) { + return doExecuteWithPartialProgress(ctx, groupStream, commitManager()); + } else { + return doExecute(ctx, groupStream, commitManager()); + } + } + + private StructLikeMap>> planFileGroups() { + Table deletesTable = + MetadataTableUtils.createMetadataTableInstance(table, MetadataTableType.POSITION_DELETES); + CloseableIterable fileTasks = planFiles(deletesTable); + + try { + StructType partitionType = Partitioning.partitionType(deletesTable); + StructLikeMap> fileTasksByPartition = + groupByPartition(partitionType, fileTasks); + return fileGroupsByPartition(fileTasksByPartition); + } finally { + try { + fileTasks.close(); + } catch (IOException io) { + LOG.error("Cannot properly close file iterable while planning for rewrite", io); + } + } + } + + private CloseableIterable planFiles(Table deletesTable) { + PositionDeletesBatchScan scan = (PositionDeletesBatchScan) deletesTable.newBatchScan(); + return CloseableIterable.transform( + scan.baseTableFilter(filter).ignoreResiduals().planFiles(), + task -> (PositionDeletesScanTask) task); + } + + private StructLikeMap> groupByPartition( + StructType partitionType, Iterable tasks) { + StructLikeMap> filesByPartition = + StructLikeMap.create(partitionType); + + for (PositionDeletesScanTask task : tasks) { + StructLike coerced = coercePartition(task, partitionType); + + List partitionTasks = filesByPartition.get(coerced); + if (partitionTasks == null) { + partitionTasks = Lists.newArrayList(); + } + partitionTasks.add(task); + filesByPartition.put(coerced, partitionTasks); + } + + return filesByPartition; + } + + private StructLikeMap>> fileGroupsByPartition( + StructLikeMap> 
filesByPartition) { + return filesByPartition.transformValues(this::planFileGroups); + } + + private List> planFileGroups(List tasks) { + return ImmutableList.copyOf(rewriter.planFileGroups(tasks)); + } + + private RewritePositionDeletesGroup rewriteDeleteFiles( + RewriteExecutionContext ctx, RewritePositionDeletesGroup fileGroup) { + String desc = jobDesc(fileGroup, ctx); + Set addedFiles = + withJobGroupInfo( + newJobGroupInfo("REWRITE-POSITION-DELETES", desc), + () -> rewriter.rewrite(fileGroup.tasks())); + + fileGroup.setOutputFiles(addedFiles); + LOG.info("Rewrite position deletes ready to be committed - {}", desc); + return fileGroup; + } + + private ExecutorService rewriteService() { + return MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) + Executors.newFixedThreadPool( + maxConcurrentFileGroupRewrites, + new ThreadFactoryBuilder() + .setNameFormat("Rewrite-Position-Delete-Service-%d") + .build())); + } + + private RewritePositionDeletesCommitManager commitManager() { + return new RewritePositionDeletesCommitManager(table, commitSummary()); + } + + private Result doExecute( + RewriteExecutionContext ctx, + Stream groupStream, + RewritePositionDeletesCommitManager commitManager) { + ExecutorService rewriteService = rewriteService(); + + ConcurrentLinkedQueue rewrittenGroups = + Queues.newConcurrentLinkedQueue(); + + Tasks.Builder rewriteTaskBuilder = + Tasks.foreach(groupStream) + .executeWith(rewriteService) + .stopOnFailure() + .noRetry() + .onFailure( + (fileGroup, exception) -> + LOG.warn( + "Failure during rewrite process for group {}", + fileGroup.info(), + exception)); + + try { + rewriteTaskBuilder.run(fileGroup -> rewrittenGroups.add(rewriteDeleteFiles(ctx, fileGroup))); + } catch (Exception e) { + // At least one rewrite group failed, clean up all completed rewrites + LOG.error( + "Cannot complete rewrite, {} is not enabled and one of the file set groups failed to " + + "be rewritten. This error occurred during the writing of new files, not during the commit process. This " + + "indicates something is wrong that doesn't involve conflicts with other Iceberg operations. Enabling " + + "{} may help in this case but the root cause should be investigated. Cleaning up {} groups which finished " + + "being written.", + PARTIAL_PROGRESS_ENABLED, + PARTIAL_PROGRESS_ENABLED, + rewrittenGroups.size(), + e); + + Tasks.foreach(rewrittenGroups).suppressFailureWhenFinished().run(commitManager::abort); + throw e; + } finally { + rewriteService.shutdown(); + } + + try { + commitManager.commitOrClean(Sets.newHashSet(rewrittenGroups)); + } catch (ValidationException | CommitFailedException e) { + String errorMessage = + String.format( + "Cannot commit rewrite because of a ValidationException or CommitFailedException. This usually means that " + + "this rewrite has conflicted with another concurrent Iceberg operation. To reduce the likelihood of " + + "conflicts, set %s which will break up the rewrite into multiple smaller commits controlled by %s. " + + "Separate smaller rewrite commits can succeed independently while any commits that conflict with " + + "another Iceberg operation will be ignored. 
This mode will create additional snapshots in the table " + + "history, one for each commit.", + PARTIAL_PROGRESS_ENABLED, PARTIAL_PROGRESS_MAX_COMMITS); + throw new RuntimeException(errorMessage, e); + } + + List rewriteResults = + rewrittenGroups.stream() + .map(RewritePositionDeletesGroup::asResult) + .collect(Collectors.toList()); + + return ImmutableRewritePositionDeleteFiles.Result.builder() + .rewriteResults(rewriteResults) + .build(); + } + + private Result doExecuteWithPartialProgress( + RewriteExecutionContext ctx, + Stream groupStream, + RewritePositionDeletesCommitManager commitManager) { + ExecutorService rewriteService = rewriteService(); + + // start commit service + int groupsPerCommit = IntMath.divide(ctx.totalGroupCount(), maxCommits, RoundingMode.CEILING); + CommitService commitService = commitManager.service(groupsPerCommit); + commitService.start(); + + // start rewrite tasks + Tasks.foreach(groupStream) + .suppressFailureWhenFinished() + .executeWith(rewriteService) + .noRetry() + .onFailure( + (fileGroup, exception) -> + LOG.error("Failure during rewrite group {}", fileGroup.info(), exception)) + .run(fileGroup -> commitService.offer(rewriteDeleteFiles(ctx, fileGroup))); + rewriteService.shutdown(); + + // stop commit service + commitService.close(); + List commitResults = commitService.results(); + if (commitResults.isEmpty()) { + LOG.error( + "{} is true but no rewrite commits succeeded. Check the logs to determine why the individual " + + "commits failed. If this is persistent it may help to increase {} which will break the rewrite operation " + + "into smaller commits.", + PARTIAL_PROGRESS_ENABLED, + PARTIAL_PROGRESS_MAX_COMMITS); + } + + List rewriteResults = + commitResults.stream() + .map(RewritePositionDeletesGroup::asResult) + .collect(Collectors.toList()); + return ImmutableRewritePositionDeleteFiles.Result.builder() + .rewriteResults(rewriteResults) + .build(); + } + + private Stream toGroupStream( + RewriteExecutionContext ctx, + Map>> groupsByPartition) { + return groupsByPartition.entrySet().stream() + .filter(e -> !e.getValue().isEmpty()) + .flatMap( + e -> { + StructLike partition = e.getKey(); + List> scanGroups = e.getValue(); + return scanGroups.stream().map(tasks -> newRewriteGroup(ctx, partition, tasks)); + }) + .sorted(RewritePositionDeletesGroup.comparator(rewriteJobOrder)); + } + + private RewritePositionDeletesGroup newRewriteGroup( + RewriteExecutionContext ctx, StructLike partition, List tasks) { + int globalIndex = ctx.currentGlobalIndex(); + int partitionIndex = ctx.currentPartitionIndex(partition); + FileGroupInfo info = + ImmutableRewritePositionDeleteFiles.FileGroupInfo.builder() + .globalIndex(globalIndex) + .partitionIndex(partitionIndex) + .partition(partition) + .build(); + return new RewritePositionDeletesGroup(info, tasks); + } + + private void validateAndInitOptions() { + Set validOptions = Sets.newHashSet(rewriter.validOptions()); + validOptions.addAll(VALID_OPTIONS); + + Set invalidKeys = Sets.newHashSet(options().keySet()); + invalidKeys.removeAll(validOptions); + + Preconditions.checkArgument( + invalidKeys.isEmpty(), + "Cannot use options %s, they are not supported by the action or the rewriter %s", + invalidKeys, + rewriter.description()); + + rewriter.init(options()); + + this.maxConcurrentFileGroupRewrites = + PropertyUtil.propertyAsInt( + options(), + MAX_CONCURRENT_FILE_GROUP_REWRITES, + MAX_CONCURRENT_FILE_GROUP_REWRITES_DEFAULT); + + this.maxCommits = + PropertyUtil.propertyAsInt( + options(), 
PARTIAL_PROGRESS_MAX_COMMITS, PARTIAL_PROGRESS_MAX_COMMITS_DEFAULT); + + this.partialProgressEnabled = + PropertyUtil.propertyAsBoolean( + options(), PARTIAL_PROGRESS_ENABLED, PARTIAL_PROGRESS_ENABLED_DEFAULT); + + this.rewriteJobOrder = + RewriteJobOrder.fromName( + PropertyUtil.propertyAsString(options(), REWRITE_JOB_ORDER, REWRITE_JOB_ORDER_DEFAULT)); + + Preconditions.checkArgument( + maxConcurrentFileGroupRewrites >= 1, + "Cannot set %s to %s, the value must be positive.", + MAX_CONCURRENT_FILE_GROUP_REWRITES, + maxConcurrentFileGroupRewrites); + + Preconditions.checkArgument( + !partialProgressEnabled || maxCommits > 0, + "Cannot set %s to %s, the value must be positive when %s is true", + PARTIAL_PROGRESS_MAX_COMMITS, + maxCommits, + PARTIAL_PROGRESS_ENABLED); + } + + private String jobDesc(RewritePositionDeletesGroup group, RewriteExecutionContext ctx) { + StructLike partition = group.info().partition(); + if (partition.size() > 0) { + return String.format( + "Rewriting %d position delete files (%s, file group %d/%d, %s (%d/%d)) in %s", + group.rewrittenDeleteFiles().size(), + rewriter.description(), + group.info().globalIndex(), + ctx.totalGroupCount(), + partition, + group.info().partitionIndex(), + ctx.groupsInPartition(partition), + table.name()); + } else { + return String.format( + "Rewriting %d position files (%s, file group %d/%d) in %s", + group.rewrittenDeleteFiles().size(), + rewriter.description(), + group.info().globalIndex(), + ctx.totalGroupCount(), + table.name()); + } + } + + static class RewriteExecutionContext { + private final StructLikeMap numGroupsByPartition; + private final int totalGroupCount; + private final Map partitionIndexMap; + private final AtomicInteger groupIndex; + + RewriteExecutionContext( + StructLikeMap>> fileTasksByPartition) { + this.numGroupsByPartition = fileTasksByPartition.transformValues(List::size); + this.totalGroupCount = numGroupsByPartition.values().stream().reduce(Integer::sum).orElse(0); + this.partitionIndexMap = Maps.newConcurrentMap(); + this.groupIndex = new AtomicInteger(1); + } + + public int currentGlobalIndex() { + return groupIndex.getAndIncrement(); + } + + public int currentPartitionIndex(StructLike partition) { + return partitionIndexMap.merge(partition, 1, Integer::sum); + } + + public int groupsInPartition(StructLike partition) { + return numGroupsByPartition.get(partition); + } + + public int totalGroupCount() { + return totalGroupCount; + } + } + + private StructLike coercePartition(PositionDeletesScanTask task, StructType partitionType) { + return PartitionUtil.coercePartition(partitionType, task.spec(), task.partition()); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SetAccumulator.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SetAccumulator.java new file mode 100644 index 000000000000..745169fc1efd --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SetAccumulator.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.util.Collections; +import java.util.Set; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.spark.util.AccumulatorV2; + +public class SetAccumulator extends AccumulatorV2> { + + private final Set set = Collections.synchronizedSet(Sets.newHashSet()); + + @Override + public boolean isZero() { + return set.isEmpty(); + } + + @Override + public AccumulatorV2> copy() { + SetAccumulator newAccumulator = new SetAccumulator<>(); + newAccumulator.set.addAll(set); + return newAccumulator; + } + + @Override + public void reset() { + set.clear(); + } + + @Override + public void add(T v) { + set.add(v); + } + + @Override + public void merge(AccumulatorV2> other) { + set.addAll(other.value()); + } + + @Override + public Set value() { + return set; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SnapshotTableSparkAction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SnapshotTableSparkAction.java new file mode 100644 index 000000000000..5f7f408cb099 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SnapshotTableSparkAction.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.util.Map; +import java.util.concurrent.ExecutorService; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.ImmutableSnapshotTable; +import org.apache.iceberg.actions.SnapshotTable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.source.StagedSparkTable; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.collection.JavaConverters; + +/** + * Creates a new Iceberg table based on a source Spark table. The new Iceberg table will have a + * different data and metadata directory allowing it to exist independently of the source table. + */ +public class SnapshotTableSparkAction extends BaseTableCreationSparkAction + implements SnapshotTable { + + private static final Logger LOG = LoggerFactory.getLogger(SnapshotTableSparkAction.class); + + private StagingTableCatalog destCatalog; + private Identifier destTableIdent; + private String destTableLocation = null; + private ExecutorService executorService; + + SnapshotTableSparkAction( + SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) { + super(spark, sourceCatalog, sourceTableIdent); + } + + @Override + protected SnapshotTableSparkAction self() { + return this; + } + + @Override + protected StagingTableCatalog destCatalog() { + return destCatalog; + } + + @Override + protected Identifier destTableIdent() { + return destTableIdent; + } + + @Override + public SnapshotTableSparkAction as(String ident) { + String ctx = "snapshot destination"; + CatalogPlugin defaultCatalog = spark().sessionState().catalogManager().currentCatalog(); + CatalogAndIdentifier catalogAndIdent = + Spark3Util.catalogAndIdentifier(ctx, spark(), ident, defaultCatalog); + this.destCatalog = checkDestinationCatalog(catalogAndIdent.catalog()); + this.destTableIdent = catalogAndIdent.identifier(); + return this; + } + + @Override + public SnapshotTableSparkAction tableProperties(Map properties) { + setProperties(properties); + return this; + } + + @Override + public SnapshotTableSparkAction tableProperty(String property, String value) { + setProperty(property, value); + return this; + } + + @Override + public SnapshotTableSparkAction executeWith(ExecutorService service) { + this.executorService = service; + return this; + } + + @Override + public SnapshotTable.Result execute() { + String desc = String.format("Snapshotting table %s as %s", sourceTableIdent(), destTableIdent); + JobGroupInfo info = newJobGroupInfo("SNAPSHOT-TABLE", desc); + return withJobGroupInfo(info, this::doExecute); + } + + private SnapshotTable.Result doExecute() { + Preconditions.checkArgument( + destCatalog() != null && destTableIdent() != null, + "The destination catalog and identifier 
cannot be null. " + + "Make sure to configure the action with a valid destination table identifier via the `as` method."); + + LOG.info( + "Staging a new Iceberg table {} as a snapshot of {}", destTableIdent(), sourceTableIdent()); + StagedSparkTable stagedTable = stageDestTable(); + Table icebergTable = stagedTable.table(); + + // TODO: Check the dest table location does not overlap with the source table location + + boolean threw = true; + try { + LOG.info("Ensuring {} has a valid name mapping", destTableIdent()); + ensureNameMappingPresent(icebergTable); + + TableIdentifier v1TableIdent = v1SourceTable().identifier(); + String stagingLocation = getMetadataLocation(icebergTable); + LOG.info("Generating Iceberg metadata for {} in {}", destTableIdent(), stagingLocation); + SparkTableUtil.importSparkTable( + spark(), v1TableIdent, icebergTable, stagingLocation, executorService); + + LOG.info("Committing staged changes to {}", destTableIdent()); + stagedTable.commitStagedChanges(); + threw = false; + } finally { + if (threw) { + LOG.error("Error when populating the staged table with metadata, aborting changes"); + + try { + stagedTable.abortStagedChanges(); + } catch (Exception abortException) { + LOG.error("Cannot abort staged changes", abortException); + } + } + } + + Snapshot snapshot = icebergTable.currentSnapshot(); + long importedDataFilesCount = + Long.parseLong(snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP)); + LOG.info( + "Successfully loaded Iceberg metadata for {} files to {}", + importedDataFilesCount, + destTableIdent()); + return ImmutableSnapshotTable.Result.builder() + .importedDataFilesCount(importedDataFilesCount) + .build(); + } + + @Override + protected Map destTableProps() { + Map properties = Maps.newHashMap(); + + // copy over relevant source table props + properties.putAll(JavaConverters.mapAsJavaMapConverter(v1SourceTable().properties()).asJava()); + EXCLUDED_PROPERTIES.forEach(properties::remove); + + // remove any possible location properties from origin properties + properties.remove(LOCATION); + properties.remove(TableProperties.WRITE_METADATA_LOCATION); + properties.remove(TableProperties.WRITE_FOLDER_STORAGE_LOCATION); + properties.remove(TableProperties.OBJECT_STORE_PATH); + properties.remove(TableProperties.WRITE_DATA_LOCATION); + + // set default and user-provided props + properties.put(TableCatalog.PROP_PROVIDER, "iceberg"); + properties.putAll(additionalProperties()); + + // make sure we mark this table as a snapshot table + properties.put(TableProperties.GC_ENABLED, "false"); + properties.put("snapshot", "true"); + + // set the destination table location if provided + if (destTableLocation != null) { + properties.put(LOCATION, destTableLocation); + } + + return properties; + } + + @Override + protected TableCatalog checkSourceCatalog(CatalogPlugin catalog) { + // currently the import code relies on being able to look up the table in the session catalog + Preconditions.checkArgument( + catalog.name().equalsIgnoreCase("spark_catalog"), + "Cannot snapshot a table that isn't in the session catalog (i.e. spark_catalog). 
" + + "Found source catalog: %s.", + catalog.name()); + + Preconditions.checkArgument( + catalog instanceof TableCatalog, + "Cannot snapshot as catalog %s of class %s in not a table catalog", + catalog.name(), + catalog.getClass().getName()); + + return (TableCatalog) catalog; + } + + @Override + public SnapshotTableSparkAction tableLocation(String location) { + Preconditions.checkArgument( + !sourceTableLocation().equals(location), + "The snapshot table location cannot be same as the source table location. " + + "This would mix snapshot table files with original table files."); + this.destTableLocation = location; + return this; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java new file mode 100644 index 000000000000..ba9fa2e7b4db --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.ActionsProvider; +import org.apache.iceberg.actions.ComputeTableStats; +import org.apache.iceberg.actions.RemoveDanglingDeleteFiles; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; + +/** + * An implementation of {@link ActionsProvider} for Spark. + * + *
<p>
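Typical entry point (an illustrative sketch, not part of the original change; {@code tsMillis} is
+ * a hypothetical expiration timestamp in milliseconds):
+ *
+ * <pre>
+ *   SparkActions.get(spark)
+ *       .expireSnapshots(table)
+ *       .expireOlderThan(tsMillis)
+ *       .execute();
+ * </pre>
+ *
+ * <p>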
This class is the primary API for interacting with actions in Spark that users should use to + * instantiate particular actions. + */ +public class SparkActions implements ActionsProvider { + + private final SparkSession spark; + + private SparkActions(SparkSession spark) { + this.spark = spark; + } + + public static SparkActions get(SparkSession spark) { + return new SparkActions(spark); + } + + public static SparkActions get() { + return new SparkActions(SparkSession.active()); + } + + @Override + public SnapshotTableSparkAction snapshotTable(String tableIdent) { + String ctx = "snapshot source"; + CatalogPlugin defaultCatalog = spark.sessionState().catalogManager().currentCatalog(); + CatalogAndIdentifier catalogAndIdent = + Spark3Util.catalogAndIdentifier(ctx, spark, tableIdent, defaultCatalog); + return new SnapshotTableSparkAction( + spark, catalogAndIdent.catalog(), catalogAndIdent.identifier()); + } + + @Override + public MigrateTableSparkAction migrateTable(String tableIdent) { + String ctx = "migrate target"; + CatalogPlugin defaultCatalog = spark.sessionState().catalogManager().currentCatalog(); + CatalogAndIdentifier catalogAndIdent = + Spark3Util.catalogAndIdentifier(ctx, spark, tableIdent, defaultCatalog); + return new MigrateTableSparkAction( + spark, catalogAndIdent.catalog(), catalogAndIdent.identifier()); + } + + @Override + public RewriteDataFilesSparkAction rewriteDataFiles(Table table) { + return new RewriteDataFilesSparkAction(spark, table); + } + + @Override + public DeleteOrphanFilesSparkAction deleteOrphanFiles(Table table) { + return new DeleteOrphanFilesSparkAction(spark, table); + } + + @Override + public RewriteManifestsSparkAction rewriteManifests(Table table) { + return new RewriteManifestsSparkAction(spark, table); + } + + @Override + public ExpireSnapshotsSparkAction expireSnapshots(Table table) { + return new ExpireSnapshotsSparkAction(spark, table); + } + + @Override + public DeleteReachableFilesSparkAction deleteReachableFiles(String metadataLocation) { + return new DeleteReachableFilesSparkAction(spark, metadataLocation); + } + + @Override + public RewritePositionDeleteFilesSparkAction rewritePositionDeletes(Table table) { + return new RewritePositionDeleteFilesSparkAction(spark, table); + } + + @Override + public ComputeTableStats computeTableStats(Table table) { + return new ComputeTableStatsSparkAction(spark, table); + } + + @Override + public RemoveDanglingDeleteFiles removeDanglingDeleteFiles(Table table) { + return new RemoveDanglingDeletesSparkAction(spark, table); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackDataRewriter.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackDataRewriter.java new file mode 100644 index 000000000000..d256bf2794e2 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackDataRewriter.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.util.List; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; + +class SparkBinPackDataRewriter extends SparkSizeBasedDataRewriter { + + SparkBinPackDataRewriter(SparkSession spark, Table table) { + super(spark, table); + } + + @Override + public String description() { + return "BIN-PACK"; + } + + @Override + protected void doRewrite(String groupId, List group) { + // read the files packing them into splits of the required size + Dataset scanDF = + spark() + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, groupId) + .option(SparkReadOptions.SPLIT_SIZE, splitSize(inputSize(group))) + .option(SparkReadOptions.FILE_OPEN_COST, "0") + .load(groupId); + + // write the packed data into new files where each split becomes a new file + scanDF + .write() + .format("iceberg") + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, groupId) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, writeMaxFileSize()) + .option(SparkWriteOptions.DISTRIBUTION_MODE, distributionMode(group).modeName()) + .option(SparkWriteOptions.OUTPUT_SPEC_ID, outputSpecId()) + .mode("append") + .save(groupId); + } + + // invoke a shuffle if the original spec does not match the output spec + private DistributionMode distributionMode(List group) { + boolean requiresRepartition = !group.get(0).spec().equals(outputSpec()); + return requiresRepartition ? DistributionMode.RANGE : DistributionMode.NONE; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackPositionDeletesRewriter.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackPositionDeletesRewriter.java new file mode 100644 index 000000000000..5afd724aad88 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackPositionDeletesRewriter.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.MetadataTableType.POSITION_DELETES; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.lit; + +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.stream.IntStream; +import org.apache.iceberg.DataFilesTable; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.MetadataTableUtils; +import org.apache.iceberg.PositionDeletesScanTask; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.SizeBasedPositionDeletesRewriter; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.PositionDeletesRewriteCoordinator; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkTableCache; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkValueConverter; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; + +class SparkBinPackPositionDeletesRewriter extends SizeBasedPositionDeletesRewriter { + + private final SparkSession spark; + private final SparkTableCache tableCache = SparkTableCache.get(); + private final ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + private final PositionDeletesRewriteCoordinator coordinator = + PositionDeletesRewriteCoordinator.get(); + + SparkBinPackPositionDeletesRewriter(SparkSession spark, Table table) { + super(table); + // Disable Adaptive Query Execution as this may change the output partitioning of our write + this.spark = spark.cloneSession(); + this.spark.conf().set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), false); + } + + @Override + public String description() { + return "BIN-PACK"; + } + + @Override + public Set rewrite(List group) { + String groupId = UUID.randomUUID().toString(); + Table deletesTable = MetadataTableUtils.createMetadataTableInstance(table(), POSITION_DELETES); + try { + tableCache.add(groupId, deletesTable); + taskSetManager.stageTasks(deletesTable, groupId, group); + + doRewrite(groupId, group); + + return coordinator.fetchNewFiles(deletesTable, groupId); + } finally { + tableCache.remove(groupId); + taskSetManager.removeTasks(deletesTable, groupId); + coordinator.clearRewrite(deletesTable, groupId); + } + } + + protected void doRewrite(String groupId, List group) { + // all position deletes are of the same partition, because they are in same file group + Preconditions.checkArgument(!group.isEmpty(), "Empty group"); + Types.StructType partitionType = group.get(0).spec().partitionType(); + StructLike partition = group.get(0).partition(); + + // read the deletes packing them into splits of the required size + Dataset posDeletes = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, groupId) + .option(SparkReadOptions.SPLIT_SIZE, splitSize(inputSize(group))) + .option(SparkReadOptions.FILE_OPEN_COST, "0") + .load(groupId); + + // keep only valid position deletes + Dataset dataFiles = dataFiles(partitionType, partition); + Column joinCond = 
posDeletes.col("file_path").equalTo(dataFiles.col("file_path")); + Dataset validDeletes = posDeletes.join(dataFiles, joinCond, "leftsemi"); + + // write the packed deletes into new files where each split becomes a new file + validDeletes + .sortWithinPartitions("file_path", "pos") + .write() + .format("iceberg") + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, groupId) + .option(SparkWriteOptions.TARGET_DELETE_FILE_SIZE_BYTES, writeMaxFileSize()) + .mode("append") + .save(groupId); + } + + /** Returns entries of {@link DataFilesTable} of specified partition */ + private Dataset dataFiles(Types.StructType partitionType, StructLike partition) { + List fields = partitionType.fields(); + Optional condition = + IntStream.range(0, fields.size()) + .mapToObj( + i -> { + Type type = fields.get(i).type(); + Object value = partition.get(i, type.typeId().javaClass()); + Object convertedValue = SparkValueConverter.convertToSpark(type, value); + Column col = col("partition.`" + fields.get(i).name() + "`"); + return col.eqNullSafe(lit(convertedValue)); + }) + .reduce(Column::and); + if (condition.isPresent()) { + return SparkTableUtil.loadMetadataTable(spark, table(), MetadataTableType.DATA_FILES) + .filter(condition.get()); + } else { + return SparkTableUtil.loadMetadataTable(spark, table(), MetadataTableType.DATA_FILES); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkShufflingDataRewriter.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkShufflingDataRewriter.java new file mode 100644 index 000000000000..ce572c6486cc --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkShufflingDataRewriter.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkFunctionCatalog; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SortOrderUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.OrderAwareCoalesce; +import org.apache.spark.sql.catalyst.plans.logical.OrderAwareCoalescer; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.distributions.OrderedDistribution; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering; +import org.apache.spark.sql.execution.datasources.v2.DistributionAndOrderingUtils$; +import scala.Option; + +abstract class SparkShufflingDataRewriter extends SparkSizeBasedDataRewriter { + + /** + * The number of shuffle partitions and consequently the number of output files created by the + * Spark sort is based on the size of the input data files used in this file rewriter. Due to + * compression, the disk file sizes may not accurately represent the size of files in the output. + * This parameter lets the user adjust the file size used for estimating actual output data size. + * A factor greater than 1.0 would generate more files than we would expect based on the on-disk + * file size. A value less than 1.0 would create fewer files than we would expect based on the + * on-disk size. + */ + public static final String COMPRESSION_FACTOR = "compression-factor"; + + public static final double COMPRESSION_FACTOR_DEFAULT = 1.0; + + /** + * The number of shuffle partitions to use for each output file. By default, this file rewriter + * assumes each shuffle partition would become a separate output file. Attempting to generate + * large output files of 512 MB or higher may strain the memory resources of the cluster as such + * rewrites would require lots of Spark memory. This parameter can be used to further divide up + * the data which will end up in a single file. For example, if the target file size is 2 GB, but + * the cluster can only handle shuffles of 512 MB, this parameter could be set to 4. Iceberg will + * use a custom coalesce operation to stitch these sorted partitions back together into a single + * sorted file. + * + *
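<p>For example, to split each output file across four shuffle partitions (an illustrative sketch,
+ * not part of the original change; assumes the table defines a sort order and that the required
+ * session extensions are enabled, as noted below):
+ *
+ * <pre>
+ *   SparkActions.get(spark)
+ *       .rewriteDataFiles(table)
+ *       .sort()
+ *       .option("shuffle-partitions-per-file", "4")
+ *       .execute();
+ * </pre>
+ *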
<p>
Note using this parameter requires enabling Iceberg Spark session extensions. + */ + public static final String SHUFFLE_PARTITIONS_PER_FILE = "shuffle-partitions-per-file"; + + public static final int SHUFFLE_PARTITIONS_PER_FILE_DEFAULT = 1; + + private double compressionFactor; + private int numShufflePartitionsPerFile; + + protected SparkShufflingDataRewriter(SparkSession spark, Table table) { + super(spark, table); + } + + protected abstract org.apache.iceberg.SortOrder sortOrder(); + + /** + * Retrieves and returns the schema for the rewrite using the current table schema. + * + *
<p>
The schema with all columns required for correctly sorting the table. This may include + * additional computed columns which are not written to the table but are used for sorting. + */ + protected Schema sortSchema() { + return table().schema(); + } + + protected abstract Dataset sortedDF( + Dataset df, Function, Dataset> sortFunc); + + @Override + public Set validOptions() { + return ImmutableSet.builder() + .addAll(super.validOptions()) + .add(COMPRESSION_FACTOR) + .add(SHUFFLE_PARTITIONS_PER_FILE) + .build(); + } + + @Override + public void init(Map options) { + super.init(options); + this.compressionFactor = compressionFactor(options); + this.numShufflePartitionsPerFile = numShufflePartitionsPerFile(options); + } + + @Override + public void doRewrite(String groupId, List group) { + Dataset scanDF = + spark() + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, groupId) + .load(groupId); + + Dataset sortedDF = sortedDF(scanDF, sortFunction(group)); + + sortedDF + .write() + .format("iceberg") + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, groupId) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, writeMaxFileSize()) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .option(SparkWriteOptions.OUTPUT_SPEC_ID, outputSpecId()) + .mode("append") + .save(groupId); + } + + private Function, Dataset> sortFunction(List group) { + SortOrder[] ordering = Spark3Util.toOrdering(outputSortOrder(group)); + int numShufflePartitions = numShufflePartitions(group); + return (df) -> transformPlan(df, plan -> sortPlan(plan, ordering, numShufflePartitions)); + } + + private LogicalPlan sortPlan(LogicalPlan plan, SortOrder[] ordering, int numShufflePartitions) { + SparkFunctionCatalog catalog = SparkFunctionCatalog.get(); + OrderedWrite write = new OrderedWrite(ordering, numShufflePartitions); + LogicalPlan sortPlan = + DistributionAndOrderingUtils$.MODULE$.prepareQuery(write, plan, Option.apply(catalog)); + + if (numShufflePartitionsPerFile == 1) { + return sortPlan; + } else { + OrderAwareCoalescer coalescer = new OrderAwareCoalescer(numShufflePartitionsPerFile); + int numOutputPartitions = numShufflePartitions / numShufflePartitionsPerFile; + return new OrderAwareCoalesce(numOutputPartitions, coalescer, sortPlan); + } + } + + private Dataset transformPlan(Dataset df, Function func) { + return new Dataset<>(spark(), func.apply(df.logicalPlan()), df.encoder()); + } + + private org.apache.iceberg.SortOrder outputSortOrder(List group) { + PartitionSpec spec = outputSpec(); + boolean requiresRepartitioning = !group.get(0).spec().equals(spec); + if (requiresRepartitioning) { + // build in the requirement for partition sorting into our sort order + // as the original spec for this group does not match the output spec + return SortOrderUtil.buildSortOrder(sortSchema(), spec, sortOrder()); + } else { + return sortOrder(); + } + } + + private int numShufflePartitions(List group) { + int numOutputFiles = (int) numOutputFiles((long) (inputSize(group) * compressionFactor)); + return Math.max(1, numOutputFiles * numShufflePartitionsPerFile); + } + + private double compressionFactor(Map options) { + double value = + PropertyUtil.propertyAsDouble(options, COMPRESSION_FACTOR, COMPRESSION_FACTOR_DEFAULT); + Preconditions.checkArgument( + value > 0, "'%s' is set to %s but must be > 0", COMPRESSION_FACTOR, value); + return value; + } + + private int numShufflePartitionsPerFile(Map options) { + int value = + PropertyUtil.propertyAsInt( + options, 
SHUFFLE_PARTITIONS_PER_FILE, SHUFFLE_PARTITIONS_PER_FILE_DEFAULT); + Preconditions.checkArgument( + value > 0, "'%s' is set to %s but must be > 0", SHUFFLE_PARTITIONS_PER_FILE, value); + Preconditions.checkArgument( + value == 1 || Spark3Util.extensionsEnabled(spark()), + "Using '%s' requires enabling Iceberg Spark session extensions", + SHUFFLE_PARTITIONS_PER_FILE); + return value; + } + + private static class OrderedWrite implements RequiresDistributionAndOrdering { + private final OrderedDistribution distribution; + private final SortOrder[] ordering; + private final int numShufflePartitions; + + OrderedWrite(SortOrder[] ordering, int numShufflePartitions) { + this.distribution = Distributions.ordered(ordering); + this.ordering = ordering; + this.numShufflePartitions = numShufflePartitions; + } + + @Override + public Distribution requiredDistribution() { + return distribution; + } + + @Override + public boolean distributionStrictlyRequired() { + return true; + } + + @Override + public int requiredNumPartitions() { + return numShufflePartitions; + } + + @Override + public SortOrder[] requiredOrdering() { + return ordering; + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSizeBasedDataRewriter.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSizeBasedDataRewriter.java new file mode 100644 index 000000000000..ae0e0d20dd4e --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSizeBasedDataRewriter.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.util.List; +import java.util.Set; +import java.util.UUID; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.SizeBasedDataRewriter; +import org.apache.iceberg.spark.FileRewriteCoordinator; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkTableCache; +import org.apache.spark.sql.SparkSession; + +abstract class SparkSizeBasedDataRewriter extends SizeBasedDataRewriter { + + private final SparkSession spark; + private final SparkTableCache tableCache = SparkTableCache.get(); + private final ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + private final FileRewriteCoordinator coordinator = FileRewriteCoordinator.get(); + + SparkSizeBasedDataRewriter(SparkSession spark, Table table) { + super(table); + this.spark = spark; + } + + protected abstract void doRewrite(String groupId, List group); + + protected SparkSession spark() { + return spark; + } + + @Override + public Set rewrite(List group) { + String groupId = UUID.randomUUID().toString(); + try { + tableCache.add(groupId, table()); + taskSetManager.stageTasks(table(), groupId, group); + + doRewrite(groupId, group); + + return coordinator.fetchNewFiles(table(), groupId); + } finally { + tableCache.remove(groupId); + taskSetManager.removeTasks(table(), groupId); + coordinator.clearRewrite(table(), groupId); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortDataRewriter.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortDataRewriter.java new file mode 100644 index 000000000000..1f70d4d7ca9d --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortDataRewriter.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.util.function.Function; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; + +class SparkSortDataRewriter extends SparkShufflingDataRewriter { + + private final SortOrder sortOrder; + + SparkSortDataRewriter(SparkSession spark, Table table) { + super(spark, table); + Preconditions.checkArgument( + table.sortOrder().isSorted(), + "Cannot sort data without a valid sort order, table '%s' is unsorted and no sort order is provided", + table.name()); + this.sortOrder = table.sortOrder(); + } + + SparkSortDataRewriter(SparkSession spark, Table table, SortOrder sortOrder) { + super(spark, table); + Preconditions.checkArgument( + sortOrder != null && sortOrder.isSorted(), + "Cannot sort data without a valid sort order, the provided sort order is null or empty"); + this.sortOrder = sortOrder; + } + + @Override + public String description() { + return "SORT"; + } + + @Override + protected SortOrder sortOrder() { + return sortOrder; + } + + @Override + protected Dataset sortedDF(Dataset df, Function, Dataset> sortFunc) { + return sortFunc.apply(df); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderDataRewriter.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderDataRewriter.java new file mode 100644 index 000000000000..cc4fb78ebd18 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderDataRewriter.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.spark.sql.functions.array; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortDirection; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.ZOrderByteUtils; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class SparkZOrderDataRewriter extends SparkShufflingDataRewriter { + + private static final Logger LOG = LoggerFactory.getLogger(SparkZOrderDataRewriter.class); + + private static final String Z_COLUMN = "ICEZVALUE"; + private static final Schema Z_SCHEMA = + new Schema(Types.NestedField.required(0, Z_COLUMN, Types.BinaryType.get())); + private static final SortOrder Z_SORT_ORDER = + SortOrder.builderFor(Z_SCHEMA) + .sortBy(Z_COLUMN, SortDirection.ASC, NullOrder.NULLS_LAST) + .build(); + + /** + * Controls the amount of bytes interleaved in the ZOrder algorithm. Default is all bytes being + * interleaved. + */ + public static final String MAX_OUTPUT_SIZE = "max-output-size"; + + public static final int MAX_OUTPUT_SIZE_DEFAULT = Integer.MAX_VALUE; + + /** + * Controls the number of bytes considered from an input column of a type with variable length + * (String, Binary). + * + *
<p>
Default is to use the same size as primitives {@link ZOrderByteUtils#PRIMITIVE_BUFFER_SIZE}. + */ + public static final String VAR_LENGTH_CONTRIBUTION = "var-length-contribution"; + + public static final int VAR_LENGTH_CONTRIBUTION_DEFAULT = ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE; + + private final List zOrderColNames; + private int maxOutputSize; + private int varLengthContribution; + + SparkZOrderDataRewriter(SparkSession spark, Table table, List zOrderColNames) { + super(spark, table); + this.zOrderColNames = validZOrderColNames(spark, table, zOrderColNames); + } + + @Override + public String description() { + return "Z-ORDER"; + } + + @Override + public Set validOptions() { + return ImmutableSet.builder() + .addAll(super.validOptions()) + .add(MAX_OUTPUT_SIZE) + .add(VAR_LENGTH_CONTRIBUTION) + .build(); + } + + @Override + public void init(Map options) { + super.init(options); + this.maxOutputSize = maxOutputSize(options); + this.varLengthContribution = varLengthContribution(options); + } + + @Override + protected SortOrder sortOrder() { + return Z_SORT_ORDER; + } + + /** + * Overrides the sortSchema method to include columns from Z_SCHEMA. + * + *
<p>
This method generates a new Schema object which consists of columns from the original table + * schema and Z_SCHEMA. + */ + @Override + protected Schema sortSchema() { + return new Schema( + new ImmutableList.Builder() + .addAll(table().schema().columns()) + .addAll(Z_SCHEMA.columns()) + .build()); + } + + @Override + protected Dataset sortedDF(Dataset df, Function, Dataset> sortFunc) { + Dataset zValueDF = df.withColumn(Z_COLUMN, zValue(df)); + Dataset sortedDF = sortFunc.apply(zValueDF); + return sortedDF.drop(Z_COLUMN); + } + + private Column zValue(Dataset df) { + SparkZOrderUDF zOrderUDF = + new SparkZOrderUDF(zOrderColNames.size(), varLengthContribution, maxOutputSize); + + Column[] zOrderCols = + zOrderColNames.stream() + .map(df.schema()::apply) + .map(col -> zOrderUDF.sortedLexicographically(df.col(col.name()), col.dataType())) + .toArray(Column[]::new); + + return zOrderUDF.interleaveBytes(array(zOrderCols)); + } + + private int varLengthContribution(Map options) { + int value = + PropertyUtil.propertyAsInt( + options, VAR_LENGTH_CONTRIBUTION, VAR_LENGTH_CONTRIBUTION_DEFAULT); + Preconditions.checkArgument( + value > 0, + "Cannot use less than 1 byte for variable length types with ZOrder, '%s' was set to %s", + VAR_LENGTH_CONTRIBUTION, + value); + return value; + } + + private int maxOutputSize(Map options) { + int value = PropertyUtil.propertyAsInt(options, MAX_OUTPUT_SIZE, MAX_OUTPUT_SIZE_DEFAULT); + Preconditions.checkArgument( + value > 0, + "Cannot have the interleaved ZOrder value use less than 1 byte, '%s' was set to %s", + MAX_OUTPUT_SIZE, + value); + return value; + } + + private List validZOrderColNames( + SparkSession spark, Table table, List inputZOrderColNames) { + + Preconditions.checkArgument( + inputZOrderColNames != null && !inputZOrderColNames.isEmpty(), + "Cannot ZOrder when no columns are specified"); + + Schema schema = table.schema(); + Set identityPartitionFieldIds = table.spec().identitySourceIds(); + boolean caseSensitive = SparkUtil.caseSensitive(spark); + + List validZOrderColNames = Lists.newArrayList(); + + for (String colName : inputZOrderColNames) { + Types.NestedField field = + caseSensitive ? schema.findField(colName) : schema.caseInsensitiveFindField(colName); + Preconditions.checkArgument( + field != null, + "Cannot find column '%s' in table schema (case sensitive = %s): %s", + colName, + caseSensitive, + schema.asStruct()); + + if (identityPartitionFieldIds.contains(field.fieldId())) { + LOG.warn("Ignoring '{}' as such values are constant within a partition", colName); + } else { + validZOrderColNames.add(colName); + } + } + + Preconditions.checkArgument( + !validZOrderColNames.isEmpty(), + "Cannot ZOrder, all columns provided were identity partition columns and cannot be used"); + + return validZOrderColNames; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderUDF.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderUDF.java new file mode 100644 index 000000000000..db359fdd62fc --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderUDF.java @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.StandardCharsets; +import org.apache.iceberg.util.ZOrderByteUtils; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.expressions.UserDefinedFunction; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.TimestampType; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +class SparkZOrderUDF implements Serializable { + private static final byte[] PRIMITIVE_EMPTY = new byte[ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE]; + + /** + * Every Spark task runs iteratively on a rows in a single thread so ThreadLocal should protect + * from concurrent access to any of these structures. 
+ */ + private transient ThreadLocal outputBuffer; + + private transient ThreadLocal inputHolder; + private transient ThreadLocal inputBuffers; + private transient ThreadLocal encoder; + + private final int numCols; + + private int inputCol = 0; + private int totalOutputBytes = 0; + private final int varTypeSize; + private final int maxOutputSize; + + SparkZOrderUDF(int numCols, int varTypeSize, int maxOutputSize) { + this.numCols = numCols; + this.varTypeSize = varTypeSize; + this.maxOutputSize = maxOutputSize; + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + inputBuffers = ThreadLocal.withInitial(() -> new ByteBuffer[numCols]); + inputHolder = ThreadLocal.withInitial(() -> new byte[numCols][]); + outputBuffer = ThreadLocal.withInitial(() -> ByteBuffer.allocate(totalOutputBytes)); + encoder = ThreadLocal.withInitial(() -> StandardCharsets.UTF_8.newEncoder()); + } + + private ByteBuffer inputBuffer(int position, int size) { + ByteBuffer buffer = inputBuffers.get()[position]; + if (buffer == null) { + buffer = ByteBuffer.allocate(size); + inputBuffers.get()[position] = buffer; + } + return buffer; + } + + byte[] interleaveBits(Seq scalaBinary) { + byte[][] columnsBinary = JavaConverters.seqAsJavaList(scalaBinary).toArray(inputHolder.get()); + return ZOrderByteUtils.interleaveBits(columnsBinary, totalOutputBytes, outputBuffer.get()); + } + + private UserDefinedFunction tinyToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (Byte value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.tinyintToOrderedBytes( + value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, + DataTypes.BinaryType) + .withName("TINY_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction shortToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (Short value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.shortToOrderedBytes( + value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, + DataTypes.BinaryType) + .withName("SHORT_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction intToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (Integer value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.intToOrderedBytes( + value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, + DataTypes.BinaryType) + .withName("INT_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction longToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (Long value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.longToOrderedBytes( + value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, + DataTypes.BinaryType) + .withName("LONG_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction floatToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf 
= + functions + .udf( + (Float value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.floatToOrderedBytes( + value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, + DataTypes.BinaryType) + .withName("FLOAT_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction doubleToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (Double value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.doubleToOrderedBytes( + value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, + DataTypes.BinaryType) + .withName("DOUBLE_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction booleanToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (Boolean value) -> { + ByteBuffer buffer = inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + buffer.put(0, (byte) (value ? -127 : 0)); + return buffer.array(); + }, + DataTypes.BinaryType) + .withName("BOOLEAN-LEXICAL-BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + return udf; + } + + private UserDefinedFunction stringToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (String value) -> + ZOrderByteUtils.stringToOrderedBytes( + value, varTypeSize, inputBuffer(position, varTypeSize), encoder.get()) + .array(), + DataTypes.BinaryType) + .withName("STRING-LEXICAL-BYTES"); + + this.inputCol++; + increaseOutputSize(varTypeSize); + + return udf; + } + + private UserDefinedFunction bytesTruncateUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (byte[] value) -> + ZOrderByteUtils.byteTruncateOrFill( + value, varTypeSize, inputBuffer(position, varTypeSize)) + .array(), + DataTypes.BinaryType) + .withName("BYTE-TRUNCATE"); + + this.inputCol++; + increaseOutputSize(varTypeSize); + + return udf; + } + + private final UserDefinedFunction interleaveUDF = + functions + .udf((Seq arrayBinary) -> interleaveBits(arrayBinary), DataTypes.BinaryType) + .withName("INTERLEAVE_BYTES"); + + Column interleaveBytes(Column arrayBinary) { + return interleaveUDF.apply(arrayBinary); + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + Column sortedLexicographically(Column column, DataType type) { + if (type instanceof ByteType) { + return tinyToOrderedBytesUDF().apply(column); + } else if (type instanceof ShortType) { + return shortToOrderedBytesUDF().apply(column); + } else if (type instanceof IntegerType) { + return intToOrderedBytesUDF().apply(column); + } else if (type instanceof LongType) { + return longToOrderedBytesUDF().apply(column); + } else if (type instanceof FloatType) { + return floatToOrderedBytesUDF().apply(column); + } else if (type instanceof DoubleType) { + return doubleToOrderedBytesUDF().apply(column); + } else if (type instanceof StringType) { + return stringToOrderedBytesUDF().apply(column); + } else if (type instanceof BinaryType) { + return bytesTruncateUDF().apply(column); + } else if (type instanceof BooleanType) { + return booleanToOrderedBytesUDF().apply(column); + } else if (type instanceof TimestampType) { + return longToOrderedBytesUDF().apply(column.cast(DataTypes.LongType)); + } else if (type instanceof DateType) { + 
return longToOrderedBytesUDF().apply(column.cast(DataTypes.LongType)); + } else { + throw new IllegalArgumentException( + String.format( + "Cannot use column %s of type %s in ZOrdering, the type is unsupported", + column, type)); + } + } + + private void increaseOutputSize(int bytes) { + totalOutputBytes = Math.min(totalOutputBytes + bytes, maxOutputSize); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/AvroWithSparkSchemaVisitor.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/AvroWithSparkSchemaVisitor.java new file mode 100644 index 000000000000..74454fc1e466 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/AvroWithSparkSchemaVisitor.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import org.apache.iceberg.avro.AvroWithPartnerByStructureVisitor; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public abstract class AvroWithSparkSchemaVisitor + extends AvroWithPartnerByStructureVisitor { + + @Override + protected boolean isStringType(DataType dataType) { + return dataType instanceof StringType; + } + + @Override + protected boolean isMapType(DataType dataType) { + return dataType instanceof MapType; + } + + @Override + protected DataType arrayElementType(DataType arrayType) { + Preconditions.checkArgument( + arrayType instanceof ArrayType, "Invalid array: %s is not an array", arrayType); + return ((ArrayType) arrayType).elementType(); + } + + @Override + protected DataType mapKeyType(DataType mapType) { + Preconditions.checkArgument(isMapType(mapType), "Invalid map: %s is not a map", mapType); + return ((MapType) mapType).keyType(); + } + + @Override + protected DataType mapValueType(DataType mapType) { + Preconditions.checkArgument(isMapType(mapType), "Invalid map: %s is not a map", mapType); + return ((MapType) mapType).valueType(); + } + + @Override + protected Pair fieldNameAndType(DataType structType, int pos) { + Preconditions.checkArgument( + structType instanceof StructType, "Invalid struct: %s is not a struct", structType); + StructField field = ((StructType) structType).apply(pos); + return Pair.of(field.name(), field.dataType()); + } + + @Override + protected DataType nullType() { + return DataTypes.NullType; + } +} diff --git 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/ParquetWithSparkSchemaVisitor.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/ParquetWithSparkSchemaVisitor.java new file mode 100644 index 000000000000..d74a76f94e87 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/ParquetWithSparkSchemaVisitor.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.util.Deque; +import java.util.List; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Type.Repetition; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * Visitor for traversing a Parquet type with a companion Spark type. 
+ * + * @param the Java class returned by the visitor + */ +public class ParquetWithSparkSchemaVisitor { + private final Deque fieldNames = Lists.newLinkedList(); + + public static T visit(DataType sType, Type type, ParquetWithSparkSchemaVisitor visitor) { + Preconditions.checkArgument(sType != null, "Invalid DataType: null"); + if (type instanceof MessageType) { + Preconditions.checkArgument( + sType instanceof StructType, "Invalid struct: %s is not a struct", sType); + StructType struct = (StructType) sType; + return visitor.message( + struct, (MessageType) type, visitFields(struct, type.asGroupType(), visitor)); + + } else if (type.isPrimitive()) { + return visitor.primitive(sType, type.asPrimitiveType()); + + } else { + // if not a primitive, the typeId must be a group + GroupType group = type.asGroupType(); + OriginalType annotation = group.getOriginalType(); + if (annotation != null) { + switch (annotation) { + case LIST: + Preconditions.checkArgument( + !group.isRepetition(Repetition.REPEATED), + "Invalid list: top-level group is repeated: %s", + group); + Preconditions.checkArgument( + group.getFieldCount() == 1, + "Invalid list: does not contain single repeated field: %s", + group); + + GroupType repeatedElement = group.getFields().get(0).asGroupType(); + Preconditions.checkArgument( + repeatedElement.isRepetition(Repetition.REPEATED), + "Invalid list: inner group is not repeated"); + Preconditions.checkArgument( + repeatedElement.getFieldCount() <= 1, + "Invalid list: repeated group is not a single field: %s", + group); + + Preconditions.checkArgument( + sType instanceof ArrayType, "Invalid list: %s is not an array", sType); + ArrayType array = (ArrayType) sType; + StructField element = + new StructField( + "element", array.elementType(), array.containsNull(), Metadata.empty()); + + visitor.fieldNames.push(repeatedElement.getName()); + try { + T elementResult = null; + if (repeatedElement.getFieldCount() > 0) { + elementResult = visitField(element, repeatedElement.getType(0), visitor); + } + + return visitor.list(array, group, elementResult); + + } finally { + visitor.fieldNames.pop(); + } + + case MAP: + Preconditions.checkArgument( + !group.isRepetition(Repetition.REPEATED), + "Invalid map: top-level group is repeated: %s", + group); + Preconditions.checkArgument( + group.getFieldCount() == 1, + "Invalid map: does not contain single repeated field: %s", + group); + + GroupType repeatedKeyValue = group.getType(0).asGroupType(); + Preconditions.checkArgument( + repeatedKeyValue.isRepetition(Repetition.REPEATED), + "Invalid map: inner group is not repeated"); + Preconditions.checkArgument( + repeatedKeyValue.getFieldCount() <= 2, + "Invalid map: repeated group does not have 2 fields"); + + Preconditions.checkArgument( + sType instanceof MapType, "Invalid map: %s is not a map", sType); + MapType map = (MapType) sType; + StructField keyField = new StructField("key", map.keyType(), false, Metadata.empty()); + StructField valueField = + new StructField( + "value", map.valueType(), map.valueContainsNull(), Metadata.empty()); + + visitor.fieldNames.push(repeatedKeyValue.getName()); + try { + T keyResult = null; + T valueResult = null; + switch (repeatedKeyValue.getFieldCount()) { + case 2: + // if there are 2 fields, both key and value are projected + keyResult = visitField(keyField, repeatedKeyValue.getType(0), visitor); + valueResult = visitField(valueField, repeatedKeyValue.getType(1), visitor); + break; + case 1: + // if there is just one, use the name to determine what it is + Type 
keyOrValue = repeatedKeyValue.getType(0); + if (keyOrValue.getName().equalsIgnoreCase("key")) { + keyResult = visitField(keyField, keyOrValue, visitor); + // value result remains null + } else { + valueResult = visitField(valueField, keyOrValue, visitor); + // key result remains null + } + break; + default: + // both results will remain null + } + + return visitor.map(map, group, keyResult, valueResult); + + } finally { + visitor.fieldNames.pop(); + } + + default: + } + } + + Preconditions.checkArgument( + sType instanceof StructType, "Invalid struct: %s is not a struct", sType); + StructType struct = (StructType) sType; + return visitor.struct(struct, group, visitFields(struct, group, visitor)); + } + } + + private static T visitField( + StructField sField, Type field, ParquetWithSparkSchemaVisitor visitor) { + visitor.fieldNames.push(field.getName()); + try { + return visit(sField.dataType(), field, visitor); + } finally { + visitor.fieldNames.pop(); + } + } + + private static List visitFields( + StructType struct, GroupType group, ParquetWithSparkSchemaVisitor visitor) { + StructField[] sFields = struct.fields(); + Preconditions.checkArgument( + sFields.length == group.getFieldCount(), "Structs do not match: %s and %s", struct, group); + List results = Lists.newArrayListWithExpectedSize(group.getFieldCount()); + for (int i = 0; i < sFields.length; i += 1) { + Type field = group.getFields().get(i); + StructField sField = sFields[i]; + Preconditions.checkArgument( + field.getName().equals(AvroSchemaUtil.makeCompatibleName(sField.name())), + "Structs do not match: field %s != %s", + field.getName(), + sField.name()); + results.add(visitField(sField, field, visitor)); + } + + return results; + } + + public T message(StructType sStruct, MessageType message, List fields) { + return null; + } + + public T struct(StructType sStruct, GroupType struct, List fields) { + return null; + } + + public T list(ArrayType sArray, GroupType array, T element) { + return null; + } + + public T map(MapType sMap, GroupType map, T key, T value) { + return null; + } + + public T primitive(DataType sPrimitive, PrimitiveType primitive) { + return null; + } + + protected String[] currentPath() { + return Lists.newArrayList(fieldNames.descendingIterator()).toArray(new String[0]); + } + + protected String[] path(String name) { + List list = Lists.newArrayList(fieldNames.descendingIterator()); + list.add(name); + return list.toArray(new String[0]); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java new file mode 100644 index 000000000000..7d92d963a9f4 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import org.apache.avro.LogicalType; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.Decoder; +import org.apache.iceberg.avro.AvroSchemaWithTypeVisitor; +import org.apache.iceberg.avro.SupportsRowPosition; +import org.apache.iceberg.avro.ValueReader; +import org.apache.iceberg.avro.ValueReaders; +import org.apache.iceberg.data.avro.DecoderResolver; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; + +/** + * @deprecated will be removed in 1.8.0; use SparkPlannedAvroReader instead. + */ +@Deprecated +public class SparkAvroReader implements DatumReader, SupportsRowPosition { + + private final Schema readSchema; + private final ValueReader reader; + private Schema fileSchema = null; + + /** + * @deprecated will be removed in 1.8.0; use SparkPlannedAvroReader instead. + */ + @Deprecated + public SparkAvroReader(org.apache.iceberg.Schema expectedSchema, Schema readSchema) { + this(expectedSchema, readSchema, ImmutableMap.of()); + } + + /** + * @deprecated will be removed in 1.8.0; use SparkPlannedAvroReader instead. + */ + @Deprecated + @SuppressWarnings("unchecked") + public SparkAvroReader( + org.apache.iceberg.Schema expectedSchema, Schema readSchema, Map constants) { + this.readSchema = readSchema; + this.reader = + (ValueReader) + AvroSchemaWithTypeVisitor.visit(expectedSchema, readSchema, new ReadBuilder(constants)); + } + + @Override + public void setSchema(Schema newFileSchema) { + this.fileSchema = Schema.applyAliases(newFileSchema, readSchema); + } + + @Override + public InternalRow read(InternalRow reuse, Decoder decoder) throws IOException { + return DecoderResolver.resolveAndRead(decoder, readSchema, fileSchema, reader, reuse); + } + + @Override + public void setRowPositionSupplier(Supplier posSupplier) { + if (reader instanceof SupportsRowPosition) { + ((SupportsRowPosition) reader).setRowPositionSupplier(posSupplier); + } + } + + private static class ReadBuilder extends AvroSchemaWithTypeVisitor> { + private final Map idToConstant; + + private ReadBuilder(Map idToConstant) { + this.idToConstant = idToConstant; + } + + @Override + public ValueReader record( + Types.StructType expected, Schema record, List names, List> fields) { + return SparkValueReaders.struct(fields, expected, idToConstant); + } + + @Override + public ValueReader union(Type expected, Schema union, List> options) { + return ValueReaders.union(options); + } + + @Override + public ValueReader array( + Types.ListType expected, Schema array, ValueReader elementReader) { + return SparkValueReaders.array(elementReader); + } + + @Override + public ValueReader map( + Types.MapType expected, Schema map, ValueReader keyReader, ValueReader valueReader) { + return SparkValueReaders.arrayMap(keyReader, valueReader); + } + + @Override + public ValueReader map(Types.MapType expected, Schema map, ValueReader valueReader) { + return SparkValueReaders.map(SparkValueReaders.strings(), valueReader); + } + + @Override + public ValueReader primitive(Type.PrimitiveType expected, Schema primitive) { + LogicalType logicalType = 
primitive.getLogicalType(); + if (logicalType != null) { + switch (logicalType.getName()) { + case "date": + // Spark uses the same representation + return ValueReaders.ints(); + + case "timestamp-millis": + // adjust to microseconds + ValueReader longs = ValueReaders.longs(); + return (ValueReader) (decoder, ignored) -> longs.read(decoder, null) * 1000L; + + case "timestamp-micros": + // Spark uses the same representation + return ValueReaders.longs(); + + case "decimal": + return SparkValueReaders.decimal( + ValueReaders.decimalBytesReader(primitive), + ((LogicalTypes.Decimal) logicalType).getScale()); + + case "uuid": + return SparkValueReaders.uuids(); + + default: + throw new IllegalArgumentException("Unknown logical type: " + logicalType); + } + } + + switch (primitive.getType()) { + case NULL: + return ValueReaders.nulls(); + case BOOLEAN: + return ValueReaders.booleans(); + case INT: + return ValueReaders.ints(); + case LONG: + return ValueReaders.longs(); + case FLOAT: + return ValueReaders.floats(); + case DOUBLE: + return ValueReaders.doubles(); + case STRING: + return SparkValueReaders.strings(); + case FIXED: + return ValueReaders.fixed(primitive.getFixedSize()); + case BYTES: + return ValueReaders.bytes(); + case ENUM: + return SparkValueReaders.enums(primitive.getEnumSymbols()); + default: + throw new IllegalArgumentException("Unsupported type: " + primitive); + } + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java new file mode 100644 index 000000000000..04dfd46a1891 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.avro.LogicalType; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.io.Encoder; +import org.apache.iceberg.FieldMetrics; +import org.apache.iceberg.avro.MetricsAwareDatumWriter; +import org.apache.iceberg.avro.ValueWriter; +import org.apache.iceberg.avro.ValueWriters; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StructType; + +public class SparkAvroWriter implements MetricsAwareDatumWriter { + private final StructType dsSchema; + private ValueWriter writer = null; + + public SparkAvroWriter(StructType dsSchema) { + this.dsSchema = dsSchema; + } + + @Override + @SuppressWarnings("unchecked") + public void setSchema(Schema schema) { + this.writer = + (ValueWriter) + AvroWithSparkSchemaVisitor.visit(dsSchema, schema, new WriteBuilder()); + } + + @Override + public void write(InternalRow datum, Encoder out) throws IOException { + writer.write(datum, out); + } + + @Override + public Stream metrics() { + return writer.metrics(); + } + + private static class WriteBuilder extends AvroWithSparkSchemaVisitor> { + @Override + public ValueWriter record( + DataType struct, Schema record, List names, List> fields) { + return SparkValueWriters.struct( + fields, + IntStream.range(0, names.size()) + .mapToObj(i -> fieldNameAndType(struct, i).second()) + .collect(Collectors.toList())); + } + + @Override + public ValueWriter union(DataType type, Schema union, List> options) { + Preconditions.checkArgument( + options.contains(ValueWriters.nulls()), + "Cannot create writer for non-option union: %s", + union); + Preconditions.checkArgument( + options.size() == 2, "Cannot create writer for non-option union: %s", union); + if (union.getTypes().get(0).getType() == Schema.Type.NULL) { + return ValueWriters.option(0, options.get(1)); + } else { + return ValueWriters.option(1, options.get(0)); + } + } + + @Override + public ValueWriter array(DataType sArray, Schema array, ValueWriter elementWriter) { + return SparkValueWriters.array(elementWriter, arrayElementType(sArray)); + } + + @Override + public ValueWriter map(DataType sMap, Schema map, ValueWriter valueReader) { + return SparkValueWriters.map( + SparkValueWriters.strings(), mapKeyType(sMap), valueReader, mapValueType(sMap)); + } + + @Override + public ValueWriter map( + DataType sMap, Schema map, ValueWriter keyWriter, ValueWriter valueWriter) { + return SparkValueWriters.arrayMap( + keyWriter, mapKeyType(sMap), valueWriter, mapValueType(sMap)); + } + + @Override + public ValueWriter primitive(DataType type, Schema primitive) { + LogicalType logicalType = primitive.getLogicalType(); + if (logicalType != null) { + switch (logicalType.getName()) { + case "date": + // Spark uses the same representation + return ValueWriters.ints(); + + case "timestamp-micros": + // Spark uses the same representation + return ValueWriters.longs(); + + case "decimal": + LogicalTypes.Decimal decimal = (LogicalTypes.Decimal) logicalType; + return SparkValueWriters.decimal(decimal.getPrecision(), decimal.getScale()); + + case "uuid": + return SparkValueWriters.uuids(); + + 
default: + throw new IllegalArgumentException("Unsupported logical type: " + logicalType); + } + } + + switch (primitive.getType()) { + case NULL: + return ValueWriters.nulls(); + case BOOLEAN: + return ValueWriters.booleans(); + case INT: + if (type instanceof ByteType) { + return ValueWriters.tinyints(); + } else if (type instanceof ShortType) { + return ValueWriters.shorts(); + } + return ValueWriters.ints(); + case LONG: + return ValueWriters.longs(); + case FLOAT: + return ValueWriters.floats(); + case DOUBLE: + return ValueWriters.doubles(); + case STRING: + return SparkValueWriters.strings(); + case FIXED: + return ValueWriters.fixed(primitive.getFixedSize()); + case BYTES: + return ValueWriters.bytes(); + default: + throw new IllegalArgumentException("Unsupported type: " + primitive); + } + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java new file mode 100644 index 000000000000..c20be44f6735 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.orc.OrcRowReader; +import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor; +import org.apache.iceberg.orc.OrcValueReader; +import org.apache.iceberg.orc.OrcValueReaders; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.ql.exec.vector.StructColumnVector; +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; +import org.apache.spark.sql.catalyst.InternalRow; + +/** + * Converts the OrcIterator, which returns ORC's VectorizedRowBatch to a set of Spark's UnsafeRows. + * + *
<p>
It minimizes allocations by reusing most of the objects in the implementation. + */ +public class SparkOrcReader implements OrcRowReader { + private final OrcValueReader reader; + + public SparkOrcReader(org.apache.iceberg.Schema expectedSchema, TypeDescription readSchema) { + this(expectedSchema, readSchema, ImmutableMap.of()); + } + + @SuppressWarnings("unchecked") + public SparkOrcReader( + org.apache.iceberg.Schema expectedSchema, + TypeDescription readOrcSchema, + Map idToConstant) { + this.reader = + OrcSchemaWithTypeVisitor.visit( + expectedSchema, readOrcSchema, new ReadBuilder(idToConstant)); + } + + @Override + public InternalRow read(VectorizedRowBatch batch, int row) { + return (InternalRow) reader.read(new StructColumnVector(batch.size, batch.cols), row); + } + + @Override + public void setBatchContext(long batchOffsetInFile) { + reader.setBatchContext(batchOffsetInFile); + } + + private static class ReadBuilder extends OrcSchemaWithTypeVisitor> { + private final Map idToConstant; + + private ReadBuilder(Map idToConstant) { + this.idToConstant = idToConstant; + } + + @Override + public OrcValueReader record( + Types.StructType expected, + TypeDescription record, + List names, + List> fields) { + return SparkOrcValueReaders.struct(fields, expected, idToConstant); + } + + @Override + public OrcValueReader list( + Types.ListType iList, TypeDescription array, OrcValueReader elementReader) { + return SparkOrcValueReaders.array(elementReader); + } + + @Override + public OrcValueReader map( + Types.MapType iMap, + TypeDescription map, + OrcValueReader keyReader, + OrcValueReader valueReader) { + return SparkOrcValueReaders.map(keyReader, valueReader); + } + + @Override + public OrcValueReader primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) { + switch (primitive.getCategory()) { + case BOOLEAN: + return OrcValueReaders.booleans(); + case BYTE: + // Iceberg does not have a byte type. Use int + case SHORT: + // Iceberg does not have a short type. Use int + case DATE: + case INT: + return OrcValueReaders.ints(); + case LONG: + return OrcValueReaders.longs(); + case FLOAT: + return OrcValueReaders.floats(); + case DOUBLE: + return OrcValueReaders.doubles(); + case TIMESTAMP_INSTANT: + case TIMESTAMP: + return SparkOrcValueReaders.timestampTzs(); + case DECIMAL: + return SparkOrcValueReaders.decimals(primitive.getPrecision(), primitive.getScale()); + case CHAR: + case VARCHAR: + case STRING: + return SparkOrcValueReaders.utf8String(); + case BINARY: + if (Type.TypeID.UUID == iPrimitive.typeId()) { + return SparkOrcValueReaders.uuids(); + } + return OrcValueReaders.bytes(); + default: + throw new IllegalArgumentException("Unhandled type " + primitive); + } + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java new file mode 100644 index 000000000000..670537fbf872 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.orc.OrcValueReader; +import org.apache.iceberg.orc.OrcValueReaders; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.UUIDUtil; +import org.apache.orc.storage.ql.exec.vector.BytesColumnVector; +import org.apache.orc.storage.ql.exec.vector.ColumnVector; +import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector; +import org.apache.orc.storage.ql.exec.vector.ListColumnVector; +import org.apache.orc.storage.ql.exec.vector.MapColumnVector; +import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector; +import org.apache.orc.storage.serde2.io.HiveDecimalWritable; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkOrcValueReaders { + private SparkOrcValueReaders() {} + + public static OrcValueReader utf8String() { + return StringReader.INSTANCE; + } + + public static OrcValueReader uuids() { + return UUIDReader.INSTANCE; + } + + public static OrcValueReader timestampTzs() { + return TimestampTzReader.INSTANCE; + } + + public static OrcValueReader decimals(int precision, int scale) { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return new SparkOrcValueReaders.Decimal18Reader(precision, scale); + } else if (precision <= 38) { + return new SparkOrcValueReaders.Decimal38Reader(precision, scale); + } else { + throw new IllegalArgumentException("Invalid precision: " + precision); + } + } + + static OrcValueReader struct( + List> readers, Types.StructType struct, Map idToConstant) { + return new StructReader(readers, struct, idToConstant); + } + + static OrcValueReader array(OrcValueReader elementReader) { + return new ArrayReader(elementReader); + } + + static OrcValueReader map(OrcValueReader keyReader, OrcValueReader valueReader) { + return new MapReader(keyReader, valueReader); + } + + private static class ArrayReader implements OrcValueReader { + private final OrcValueReader elementReader; + + private ArrayReader(OrcValueReader elementReader) { + this.elementReader = elementReader; + } + + @Override + public ArrayData nonNullRead(ColumnVector vector, int row) { + ListColumnVector listVector = (ListColumnVector) vector; + int offset = (int) listVector.offsets[row]; + int length = (int) listVector.lengths[row]; + List elements = Lists.newArrayListWithExpectedSize(length); + for (int c = 0; c < length; ++c) { + elements.add(elementReader.read(listVector.child, offset + c)); + } + return new 
GenericArrayData(elements.toArray()); + } + + @Override + public void setBatchContext(long batchOffsetInFile) { + elementReader.setBatchContext(batchOffsetInFile); + } + } + + private static class MapReader implements OrcValueReader { + private final OrcValueReader keyReader; + private final OrcValueReader valueReader; + + private MapReader(OrcValueReader keyReader, OrcValueReader valueReader) { + this.keyReader = keyReader; + this.valueReader = valueReader; + } + + @Override + public MapData nonNullRead(ColumnVector vector, int row) { + MapColumnVector mapVector = (MapColumnVector) vector; + int offset = (int) mapVector.offsets[row]; + long length = mapVector.lengths[row]; + List keys = Lists.newArrayListWithExpectedSize((int) length); + List values = Lists.newArrayListWithExpectedSize((int) length); + for (int c = 0; c < length; c++) { + keys.add(keyReader.read(mapVector.keys, offset + c)); + values.add(valueReader.read(mapVector.values, offset + c)); + } + + return new ArrayBasedMapData( + new GenericArrayData(keys.toArray()), new GenericArrayData(values.toArray())); + } + + @Override + public void setBatchContext(long batchOffsetInFile) { + keyReader.setBatchContext(batchOffsetInFile); + valueReader.setBatchContext(batchOffsetInFile); + } + } + + static class StructReader extends OrcValueReaders.StructReader { + private final int numFields; + + protected StructReader( + List> readers, Types.StructType struct, Map idToConstant) { + super(readers, struct, idToConstant); + this.numFields = struct.fields().size(); + } + + @Override + protected InternalRow create() { + return new GenericInternalRow(numFields); + } + + @Override + protected void set(InternalRow struct, int pos, Object value) { + if (value != null) { + struct.update(pos, value); + } else { + struct.setNullAt(pos); + } + } + } + + private static class StringReader implements OrcValueReader { + private static final StringReader INSTANCE = new StringReader(); + + private StringReader() {} + + @Override + public UTF8String nonNullRead(ColumnVector vector, int row) { + BytesColumnVector bytesVector = (BytesColumnVector) vector; + return UTF8String.fromBytes( + bytesVector.vector[row], bytesVector.start[row], bytesVector.length[row]); + } + } + + private static class UUIDReader implements OrcValueReader { + private static final UUIDReader INSTANCE = new UUIDReader(); + + private UUIDReader() {} + + @Override + public UTF8String nonNullRead(ColumnVector vector, int row) { + BytesColumnVector bytesVector = (BytesColumnVector) vector; + ByteBuffer buffer = + ByteBuffer.wrap(bytesVector.vector[row], bytesVector.start[row], bytesVector.length[row]); + return UTF8String.fromString(UUIDUtil.convert(buffer).toString()); + } + } + + private static class TimestampTzReader implements OrcValueReader { + private static final TimestampTzReader INSTANCE = new TimestampTzReader(); + + private TimestampTzReader() {} + + @Override + public Long nonNullRead(ColumnVector vector, int row) { + TimestampColumnVector tcv = (TimestampColumnVector) vector; + return Math.floorDiv(tcv.time[row], 1_000) * 1_000_000 + Math.floorDiv(tcv.nanos[row], 1000); + } + } + + private static class Decimal18Reader implements OrcValueReader { + private final int precision; + private final int scale; + + Decimal18Reader(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + + @Override + public Decimal nonNullRead(ColumnVector vector, int row) { + HiveDecimalWritable value = ((DecimalColumnVector) vector).vector[row]; + + // The scale of 
decimal read from hive ORC file may be not equals to the expected scale. For + // data type + // decimal(10,3) and the value 10.100, the hive ORC writer will remove its trailing zero and + // store it + // as 101*10^(-1), its scale will adjust from 3 to 1. So here we could not assert that + // value.scale() == scale. + // we also need to convert the hive orc decimal to a decimal with expected precision and + // scale. + Preconditions.checkArgument( + value.precision() <= precision, + "Cannot read value as decimal(%s,%s), too large: %s", + precision, + scale, + value); + + return new Decimal().set(value.serialize64(scale), precision, scale); + } + } + + private static class Decimal38Reader implements OrcValueReader { + private final int precision; + private final int scale; + + Decimal38Reader(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + + @Override + public Decimal nonNullRead(ColumnVector vector, int row) { + BigDecimal value = + ((DecimalColumnVector) vector).vector[row].getHiveDecimal().bigDecimalValue(); + + Preconditions.checkArgument( + value.precision() <= precision, + "Cannot read value as decimal(%s,%s), too large: %s", + precision, + scale, + value); + + return new Decimal().set(new scala.math.BigDecimal(value), precision, scale); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java new file mode 100644 index 000000000000..7f9810e4c60c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.UUID; +import java.util.stream.Stream; +import org.apache.iceberg.FieldMetrics; +import org.apache.iceberg.orc.OrcValueWriter; +import org.apache.iceberg.util.UUIDUtil; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.common.type.HiveDecimal; +import org.apache.orc.storage.ql.exec.vector.BytesColumnVector; +import org.apache.orc.storage.ql.exec.vector.ColumnVector; +import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector; +import org.apache.orc.storage.ql.exec.vector.ListColumnVector; +import org.apache.orc.storage.ql.exec.vector.MapColumnVector; +import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +class SparkOrcValueWriters { + private SparkOrcValueWriters() {} + + static OrcValueWriter strings() { + return StringWriter.INSTANCE; + } + + static OrcValueWriter uuids() { + return UUIDWriter.INSTANCE; + } + + static OrcValueWriter timestampTz() { + return TimestampTzWriter.INSTANCE; + } + + static OrcValueWriter decimal(int precision, int scale) { + if (precision <= 18) { + return new Decimal18Writer(scale); + } else { + return new Decimal38Writer(); + } + } + + static OrcValueWriter list(OrcValueWriter element, List orcType) { + return new ListWriter<>(element, orcType); + } + + static OrcValueWriter map( + OrcValueWriter keyWriter, OrcValueWriter valueWriter, List orcTypes) { + return new MapWriter<>(keyWriter, valueWriter, orcTypes); + } + + private static class StringWriter implements OrcValueWriter { + private static final StringWriter INSTANCE = new StringWriter(); + + @Override + public void nonNullWrite(int rowId, UTF8String data, ColumnVector output) { + byte[] value = data.getBytes(); + ((BytesColumnVector) output).setRef(rowId, value, 0, value.length); + } + } + + private static class UUIDWriter implements OrcValueWriter { + private static final UUIDWriter INSTANCE = new UUIDWriter(); + + @Override + public void nonNullWrite(int rowId, UTF8String data, ColumnVector output) { + // ((BytesColumnVector) output).setRef(..) 
just stores a reference to the passed byte[], so + // can't use a ThreadLocal ByteBuffer here like in other places because subsequent writes + // would then overwrite previous values + ByteBuffer buffer = UUIDUtil.convertToByteBuffer(UUID.fromString(data.toString())); + ((BytesColumnVector) output).setRef(rowId, buffer.array(), 0, buffer.array().length); + } + } + + private static class TimestampTzWriter implements OrcValueWriter { + private static final TimestampTzWriter INSTANCE = new TimestampTzWriter(); + + @Override + public void nonNullWrite(int rowId, Long micros, ColumnVector output) { + TimestampColumnVector cv = (TimestampColumnVector) output; + cv.time[rowId] = Math.floorDiv(micros, 1_000); // millis + cv.nanos[rowId] = (int) Math.floorMod(micros, 1_000_000) * 1_000; // nanos + } + } + + private static class Decimal18Writer implements OrcValueWriter { + private final int scale; + + Decimal18Writer(int scale) { + this.scale = scale; + } + + @Override + public void nonNullWrite(int rowId, Decimal decimal, ColumnVector output) { + ((DecimalColumnVector) output) + .vector[rowId].setFromLongAndScale(decimal.toUnscaledLong(), scale); + } + } + + private static class Decimal38Writer implements OrcValueWriter { + + @Override + public void nonNullWrite(int rowId, Decimal decimal, ColumnVector output) { + ((DecimalColumnVector) output) + .vector[rowId].set(HiveDecimal.create(decimal.toJavaBigDecimal())); + } + } + + private static class ListWriter implements OrcValueWriter { + private final OrcValueWriter writer; + private final SparkOrcWriter.FieldGetter fieldGetter; + + @SuppressWarnings("unchecked") + ListWriter(OrcValueWriter writer, List orcTypes) { + if (orcTypes.size() != 1) { + throw new IllegalArgumentException( + "Expected one (and same) ORC type for list elements, got: " + orcTypes); + } + this.writer = writer; + this.fieldGetter = + (SparkOrcWriter.FieldGetter) SparkOrcWriter.createFieldGetter(orcTypes.get(0)); + } + + @Override + public void nonNullWrite(int rowId, ArrayData value, ColumnVector output) { + ListColumnVector cv = (ListColumnVector) output; + // record the length and start of the list elements + cv.lengths[rowId] = value.numElements(); + cv.offsets[rowId] = cv.childCount; + cv.childCount = (int) (cv.childCount + cv.lengths[rowId]); + // make sure the child is big enough + growColumnVector(cv.child, cv.childCount); + // Add each element + for (int e = 0; e < cv.lengths[rowId]; ++e) { + writer.write((int) (e + cv.offsets[rowId]), fieldGetter.getFieldOrNull(value, e), cv.child); + } + } + + @Override + public Stream> metrics() { + return writer.metrics(); + } + } + + private static class MapWriter implements OrcValueWriter { + private final OrcValueWriter keyWriter; + private final OrcValueWriter valueWriter; + private final SparkOrcWriter.FieldGetter keyFieldGetter; + private final SparkOrcWriter.FieldGetter valueFieldGetter; + + @SuppressWarnings("unchecked") + MapWriter( + OrcValueWriter keyWriter, + OrcValueWriter valueWriter, + List orcTypes) { + if (orcTypes.size() != 2) { + throw new IllegalArgumentException( + "Expected two ORC type descriptions for a map, got: " + orcTypes); + } + this.keyWriter = keyWriter; + this.valueWriter = valueWriter; + this.keyFieldGetter = + (SparkOrcWriter.FieldGetter) SparkOrcWriter.createFieldGetter(orcTypes.get(0)); + this.valueFieldGetter = + (SparkOrcWriter.FieldGetter) SparkOrcWriter.createFieldGetter(orcTypes.get(1)); + } + + @Override + public void nonNullWrite(int rowId, MapData map, ColumnVector output) { + ArrayData 
key = map.keyArray(); + ArrayData value = map.valueArray(); + MapColumnVector cv = (MapColumnVector) output; + // record the length and start of the list elements + cv.lengths[rowId] = value.numElements(); + cv.offsets[rowId] = cv.childCount; + cv.childCount = (int) (cv.childCount + cv.lengths[rowId]); + // make sure the child is big enough + growColumnVector(cv.keys, cv.childCount); + growColumnVector(cv.values, cv.childCount); + // Add each element + for (int e = 0; e < cv.lengths[rowId]; ++e) { + int pos = (int) (e + cv.offsets[rowId]); + keyWriter.write(pos, keyFieldGetter.getFieldOrNull(key, e), cv.keys); + valueWriter.write(pos, valueFieldGetter.getFieldOrNull(value, e), cv.values); + } + } + + @Override + public Stream> metrics() { + return Stream.concat(keyWriter.metrics(), valueWriter.metrics()); + } + } + + private static void growColumnVector(ColumnVector cv, int requestedSize) { + if (cv.isNull.length < requestedSize) { + // Use growth factor of 3 to avoid frequent array allocations + cv.ensureSize(requestedSize * 3, true); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java new file mode 100644 index 000000000000..6b799e677bf4 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.io.Serializable; +import java.util.List; +import java.util.stream.Stream; +import javax.annotation.Nullable; +import org.apache.iceberg.FieldMetrics; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.orc.GenericOrcWriters; +import org.apache.iceberg.orc.ORCSchemaUtil; +import org.apache.iceberg.orc.OrcRowWriter; +import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor; +import org.apache.iceberg.orc.OrcValueWriter; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; + +/** This class acts as an adaptor from an OrcFileAppender to a FileAppender<InternalRow>. 
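For reference, the TimestampTzWriter defined earlier in SparkOrcValueWriters splits epoch microseconds into the millisecond and nano-of-second fields that ORC's TimestampColumnVector expects. A standalone, JDK-only sketch (the class name is illustrative, not part of this patch) of why Math.floorDiv/floorMod are used rather than truncating / and %, which misbehave for pre-epoch values:

import java.time.Instant;

// Illustrative sketch, not part of this patch: the micros -> (millis, nano-of-second)
// split performed by TimestampTzWriter, and why floor arithmetic matters.
public class MicrosSplitSketch {
  public static void main(String[] args) {
    long micros = -1L; // one microsecond before the epoch: 1969-12-31T23:59:59.999999Z

    long millis = Math.floorDiv(micros, 1_000);                 // -1
    int nanos = (int) Math.floorMod(micros, 1_000_000) * 1_000; // 999_999_000

    // Truncating arithmetic would yield (0, -1000), i.e. a time after the epoch.
    System.out.println(millis + " ms, " + nanos + " ns of second");
    System.out.println(Instant.ofEpochMilli(millis)); // 1969-12-31T23:59:59.999Z
  }
}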
*/ +public class SparkOrcWriter implements OrcRowWriter { + + private final InternalRowWriter writer; + + public SparkOrcWriter(Schema iSchema, TypeDescription orcSchema) { + Preconditions.checkArgument( + orcSchema.getCategory() == TypeDescription.Category.STRUCT, + "Top level must be a struct " + orcSchema); + + writer = + (InternalRowWriter) OrcSchemaWithTypeVisitor.visit(iSchema, orcSchema, new WriteBuilder()); + } + + @Override + public void write(InternalRow value, VectorizedRowBatch output) { + Preconditions.checkArgument(value != null, "value must not be null"); + writer.writeRow(value, output); + } + + @Override + public List> writers() { + return writer.writers(); + } + + @Override + public Stream> metrics() { + return writer.metrics(); + } + + private static class WriteBuilder extends OrcSchemaWithTypeVisitor> { + private WriteBuilder() {} + + @Override + public OrcValueWriter record( + Types.StructType iStruct, + TypeDescription record, + List names, + List> fields) { + return new InternalRowWriter(fields, record.getChildren()); + } + + @Override + public OrcValueWriter list( + Types.ListType iList, TypeDescription array, OrcValueWriter element) { + return SparkOrcValueWriters.list(element, array.getChildren()); + } + + @Override + public OrcValueWriter map( + Types.MapType iMap, TypeDescription map, OrcValueWriter key, OrcValueWriter value) { + return SparkOrcValueWriters.map(key, value, map.getChildren()); + } + + @Override + public OrcValueWriter primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) { + switch (primitive.getCategory()) { + case BOOLEAN: + return GenericOrcWriters.booleans(); + case BYTE: + return GenericOrcWriters.bytes(); + case SHORT: + return GenericOrcWriters.shorts(); + case DATE: + case INT: + return GenericOrcWriters.ints(); + case LONG: + return GenericOrcWriters.longs(); + case FLOAT: + return GenericOrcWriters.floats(ORCSchemaUtil.fieldId(primitive)); + case DOUBLE: + return GenericOrcWriters.doubles(ORCSchemaUtil.fieldId(primitive)); + case BINARY: + if (Type.TypeID.UUID == iPrimitive.typeId()) { + return SparkOrcValueWriters.uuids(); + } + return GenericOrcWriters.byteArrays(); + case STRING: + case CHAR: + case VARCHAR: + return SparkOrcValueWriters.strings(); + case DECIMAL: + return SparkOrcValueWriters.decimal(primitive.getPrecision(), primitive.getScale()); + case TIMESTAMP_INSTANT: + case TIMESTAMP: + return SparkOrcValueWriters.timestampTz(); + default: + throw new IllegalArgumentException("Unhandled type " + primitive); + } + } + } + + private static class InternalRowWriter extends GenericOrcWriters.StructWriter { + private final List> fieldGetters; + + InternalRowWriter(List> writers, List orcTypes) { + super(writers); + this.fieldGetters = Lists.newArrayListWithExpectedSize(orcTypes.size()); + + for (TypeDescription orcType : orcTypes) { + fieldGetters.add(createFieldGetter(orcType)); + } + } + + @Override + protected Object get(InternalRow struct, int index) { + return fieldGetters.get(index).getFieldOrNull(struct, index); + } + } + + static FieldGetter createFieldGetter(TypeDescription fieldType) { + final FieldGetter fieldGetter; + switch (fieldType.getCategory()) { + case BOOLEAN: + fieldGetter = SpecializedGetters::getBoolean; + break; + case BYTE: + fieldGetter = SpecializedGetters::getByte; + break; + case SHORT: + fieldGetter = SpecializedGetters::getShort; + break; + case DATE: + case INT: + fieldGetter = SpecializedGetters::getInt; + break; + case LONG: + case TIMESTAMP: + case TIMESTAMP_INSTANT: + fieldGetter = 
SpecializedGetters::getLong; + break; + case FLOAT: + fieldGetter = SpecializedGetters::getFloat; + break; + case DOUBLE: + fieldGetter = SpecializedGetters::getDouble; + break; + case BINARY: + if (ORCSchemaUtil.BinaryType.UUID + .toString() + .equalsIgnoreCase( + fieldType.getAttributeValue(ORCSchemaUtil.ICEBERG_BINARY_TYPE_ATTRIBUTE))) { + fieldGetter = SpecializedGetters::getUTF8String; + } else { + fieldGetter = SpecializedGetters::getBinary; + } + // getBinary always makes a copy, so we don't need to worry about it + // being changed behind our back. + break; + case DECIMAL: + fieldGetter = + (row, ordinal) -> + row.getDecimal(ordinal, fieldType.getPrecision(), fieldType.getScale()); + break; + case STRING: + case CHAR: + case VARCHAR: + fieldGetter = SpecializedGetters::getUTF8String; + break; + case STRUCT: + fieldGetter = (row, ordinal) -> row.getStruct(ordinal, fieldType.getChildren().size()); + break; + case LIST: + fieldGetter = SpecializedGetters::getArray; + break; + case MAP: + fieldGetter = SpecializedGetters::getMap; + break; + default: + throw new IllegalArgumentException( + "Encountered an unsupported ORC type during a write from Spark."); + } + + return (row, ordinal) -> { + if (row.isNullAt(ordinal)) { + return null; + } + return fieldGetter.getFieldOrNull(row, ordinal); + }; + } + + interface FieldGetter extends Serializable { + + /** + * Returns a value from a complex Spark data holder such ArrayData, InternalRow, etc... Calls + * the appropriate getter for the expected data type. + * + * @param row Spark's data representation + * @param ordinal index in the data structure (e.g. column index for InterRow, list index in + * ArrayData, etc..) + * @return field value at ordinal + */ + @Nullable + T getFieldOrNull(SpecializedGetters row, int ordinal); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java new file mode 100644 index 000000000000..687b26b83187 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java @@ -0,0 +1,792 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
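The createFieldGetter factory above returns each type-specific getter wrapped in a null check, so InternalRowWriter and the list/map writers never invoke a getter on a null ordinal. A minimal sketch of that wrapping pattern with plain JDK types (Getter and nullSafe are illustrative stand-ins, not Iceberg or Spark APIs):

import java.util.function.IntFunction;

// Illustrative stand-ins, not part of this patch: the same decorate-with-a-null-check
// idea used by createFieldGetter above.
public class NullGuardSketch {
  interface Getter<T> {
    T get(IntFunction<Object> row, int ordinal);
  }

  static <T> Getter<T> nullSafe(Getter<T> getter) {
    // return null for null slots, otherwise delegate to the typed getter
    return (row, ordinal) -> row.apply(ordinal) == null ? null : getter.get(row, ordinal);
  }

  public static void main(String[] args) {
    Object[] row = {42L, null};
    Getter<Long> longs = nullSafe((r, i) -> (Long) r.apply(i));
    System.out.println(longs.get(i -> row[i], 0)); // 42
    System.out.println(longs.get(i -> row[i], 1)); // null
  }
}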
+ */ +package org.apache.iceberg.spark.data; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.parquet.ParquetUtil; +import org.apache.iceberg.parquet.ParquetValueReader; +import org.apache.iceberg.parquet.ParquetValueReaders; +import org.apache.iceberg.parquet.ParquetValueReaders.FloatAsDoubleReader; +import org.apache.iceberg.parquet.ParquetValueReaders.IntAsLongReader; +import org.apache.iceberg.parquet.ParquetValueReaders.PrimitiveReader; +import org.apache.iceberg.parquet.ParquetValueReaders.RepeatedKeyValueReader; +import org.apache.iceberg.parquet.ParquetValueReaders.RepeatedReader; +import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry; +import org.apache.iceberg.parquet.ParquetValueReaders.StructReader; +import org.apache.iceberg.parquet.ParquetValueReaders.UnboxedReader; +import org.apache.iceberg.parquet.TypeWithSchemaVisitor; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type.TypeID; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.UUIDUtil; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; + +public class SparkParquetReaders { + private SparkParquetReaders() {} + + public static ParquetValueReader buildReader( + Schema expectedSchema, MessageType fileSchema) { + return buildReader(expectedSchema, fileSchema, ImmutableMap.of()); + } + + @SuppressWarnings("unchecked") + public static ParquetValueReader buildReader( + Schema expectedSchema, MessageType fileSchema, Map idToConstant) { + if (ParquetSchemaUtil.hasIds(fileSchema)) { + return (ParquetValueReader) + TypeWithSchemaVisitor.visit( + expectedSchema.asStruct(), fileSchema, new ReadBuilder(fileSchema, idToConstant)); + } else { + return (ParquetValueReader) + TypeWithSchemaVisitor.visit( + expectedSchema.asStruct(), + fileSchema, + new FallbackReadBuilder(fileSchema, idToConstant)); + } + } + + private static class FallbackReadBuilder extends ReadBuilder { + FallbackReadBuilder(MessageType type, Map idToConstant) { + super(type, idToConstant); + } + + @Override + public ParquetValueReader message( + Types.StructType expected, 
MessageType message, List> fieldReaders) { + // the top level matches by ID, but the remaining IDs are missing + return super.struct(expected, message, fieldReaders); + } + + @Override + public ParquetValueReader struct( + Types.StructType ignored, GroupType struct, List> fieldReaders) { + // the expected struct is ignored because nested fields are never found when the + List> newFields = + Lists.newArrayListWithExpectedSize(fieldReaders.size()); + List types = Lists.newArrayListWithExpectedSize(fieldReaders.size()); + List fields = struct.getFields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i); + int fieldD = type().getMaxDefinitionLevel(path(fieldType.getName())) - 1; + newFields.add(ParquetValueReaders.option(fieldType, fieldD, fieldReaders.get(i))); + types.add(fieldType); + } + + return new InternalRowReader(types, newFields); + } + } + + private static class ReadBuilder extends TypeWithSchemaVisitor> { + private final MessageType type; + private final Map idToConstant; + + ReadBuilder(MessageType type, Map idToConstant) { + this.type = type; + this.idToConstant = idToConstant; + } + + @Override + public ParquetValueReader message( + Types.StructType expected, MessageType message, List> fieldReaders) { + return struct(expected, message.asGroupType(), fieldReaders); + } + + @Override + public ParquetValueReader struct( + Types.StructType expected, GroupType struct, List> fieldReaders) { + // match the expected struct's order + Map> readersById = Maps.newHashMap(); + Map typesById = Maps.newHashMap(); + Map maxDefinitionLevelsById = Maps.newHashMap(); + List fields = struct.getFields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i); + int fieldD = type.getMaxDefinitionLevel(path(fieldType.getName())) - 1; + if (fieldType.getId() != null) { + int id = fieldType.getId().intValue(); + readersById.put(id, ParquetValueReaders.option(fieldType, fieldD, fieldReaders.get(i))); + typesById.put(id, fieldType); + if (idToConstant.containsKey(id)) { + maxDefinitionLevelsById.put(id, fieldD); + } + } + } + + List expectedFields = + expected != null ? 
expected.fields() : ImmutableList.of(); + List> reorderedFields = + Lists.newArrayListWithExpectedSize(expectedFields.size()); + List types = Lists.newArrayListWithExpectedSize(expectedFields.size()); + // Defaulting to parent max definition level + int defaultMaxDefinitionLevel = type.getMaxDefinitionLevel(currentPath()); + for (Types.NestedField field : expectedFields) { + int id = field.fieldId(); + if (idToConstant.containsKey(id)) { + // containsKey is used because the constant may be null + int fieldMaxDefinitionLevel = + maxDefinitionLevelsById.getOrDefault(id, defaultMaxDefinitionLevel); + reorderedFields.add( + ParquetValueReaders.constant(idToConstant.get(id), fieldMaxDefinitionLevel)); + types.add(null); + } else if (id == MetadataColumns.ROW_POSITION.fieldId()) { + reorderedFields.add(ParquetValueReaders.position()); + types.add(null); + } else if (id == MetadataColumns.IS_DELETED.fieldId()) { + reorderedFields.add(ParquetValueReaders.constant(false)); + types.add(null); + } else { + ParquetValueReader reader = readersById.get(id); + if (reader != null) { + reorderedFields.add(reader); + types.add(typesById.get(id)); + } else { + reorderedFields.add(ParquetValueReaders.nulls()); + types.add(null); + } + } + } + + return new InternalRowReader(types, reorderedFields); + } + + @Override + public ParquetValueReader list( + Types.ListType expectedList, GroupType array, ParquetValueReader elementReader) { + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath) - 1; + int repeatedR = type.getMaxRepetitionLevel(repeatedPath) - 1; + + Type elementType = ParquetSchemaUtil.determineListElementType(array); + int elementD = type.getMaxDefinitionLevel(path(elementType.getName())) - 1; + + return new ArrayReader<>( + repeatedD, repeatedR, ParquetValueReaders.option(elementType, elementD, elementReader)); + } + + @Override + public ParquetValueReader map( + Types.MapType expectedMap, + GroupType map, + ParquetValueReader keyReader, + ParquetValueReader valueReader) { + GroupType repeatedKeyValue = map.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath) - 1; + int repeatedR = type.getMaxRepetitionLevel(repeatedPath) - 1; + + Type keyType = repeatedKeyValue.getType(0); + int keyD = type.getMaxDefinitionLevel(path(keyType.getName())) - 1; + Type valueType = repeatedKeyValue.getType(1); + int valueD = type.getMaxDefinitionLevel(path(valueType.getName())) - 1; + + return new MapReader<>( + repeatedD, + repeatedR, + ParquetValueReaders.option(keyType, keyD, keyReader), + ParquetValueReaders.option(valueType, valueD, valueReader)); + } + + @Override + @SuppressWarnings("checkstyle:CyclomaticComplexity") + public ParquetValueReader primitive( + org.apache.iceberg.types.Type.PrimitiveType expected, PrimitiveType primitive) { + ColumnDescriptor desc = type.getColumnDescription(currentPath()); + + if (primitive.getOriginalType() != null) { + switch (primitive.getOriginalType()) { + case ENUM: + case JSON: + case UTF8: + return new StringReader(desc); + case INT_8: + case INT_16: + case INT_32: + if (expected != null && expected.typeId() == Types.LongType.get().typeId()) { + return new IntAsLongReader(desc); + } else { + return new UnboxedReader(desc); + } + case DATE: + case INT_64: + case TIMESTAMP_MICROS: + return new UnboxedReader<>(desc); + case TIMESTAMP_MILLIS: + return new TimestampMillisReader(desc); + case DECIMAL: + DecimalLogicalTypeAnnotation decimal = + 
(DecimalLogicalTypeAnnotation) primitive.getLogicalTypeAnnotation(); + switch (primitive.getPrimitiveTypeName()) { + case BINARY: + case FIXED_LEN_BYTE_ARRAY: + return new BinaryDecimalReader(desc, decimal.getScale()); + case INT64: + return new LongDecimalReader(desc, decimal.getPrecision(), decimal.getScale()); + case INT32: + return new IntegerDecimalReader(desc, decimal.getPrecision(), decimal.getScale()); + default: + throw new UnsupportedOperationException( + "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName()); + } + case BSON: + return new ParquetValueReaders.ByteArrayReader(desc); + default: + throw new UnsupportedOperationException( + "Unsupported logical type: " + primitive.getOriginalType()); + } + } + + switch (primitive.getPrimitiveTypeName()) { + case FIXED_LEN_BYTE_ARRAY: + case BINARY: + if (expected != null && expected.typeId() == TypeID.UUID) { + return new UUIDReader(desc); + } + return new ParquetValueReaders.ByteArrayReader(desc); + case INT32: + if (expected != null && expected.typeId() == TypeID.LONG) { + return new IntAsLongReader(desc); + } else { + return new UnboxedReader<>(desc); + } + case FLOAT: + if (expected != null && expected.typeId() == TypeID.DOUBLE) { + return new FloatAsDoubleReader(desc); + } else { + return new UnboxedReader<>(desc); + } + case BOOLEAN: + case INT64: + case DOUBLE: + return new UnboxedReader<>(desc); + case INT96: + // Impala & Spark used to write timestamps as INT96 without a logical type. For backwards + // compatibility we try to read INT96 as timestamps. + return new TimestampInt96Reader(desc); + default: + throw new UnsupportedOperationException("Unsupported type: " + primitive); + } + } + + protected MessageType type() { + return type; + } + } + + private static class BinaryDecimalReader extends PrimitiveReader { + private final int scale; + + BinaryDecimalReader(ColumnDescriptor desc, int scale) { + super(desc); + this.scale = scale; + } + + @Override + public Decimal read(Decimal ignored) { + Binary binary = column.nextBinary(); + return Decimal.fromDecimal(new BigDecimal(new BigInteger(binary.getBytes()), scale)); + } + } + + private static class IntegerDecimalReader extends PrimitiveReader { + private final int precision; + private final int scale; + + IntegerDecimalReader(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public Decimal read(Decimal ignored) { + return Decimal.apply(column.nextInteger(), precision, scale); + } + } + + private static class LongDecimalReader extends PrimitiveReader { + private final int precision; + private final int scale; + + LongDecimalReader(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public Decimal read(Decimal ignored) { + return Decimal.apply(column.nextLong(), precision, scale); + } + } + + private static class TimestampMillisReader extends UnboxedReader { + TimestampMillisReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public Long read(Long ignored) { + return readLong(); + } + + @Override + public long readLong() { + return 1000 * column.nextLong(); + } + } + + private static class TimestampInt96Reader extends UnboxedReader { + + TimestampInt96Reader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public Long read(Long ignored) { + return readLong(); + } + + @Override + public long readLong() { + final ByteBuffer byteBuffer = + 
column.nextBinary().toByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + return ParquetUtil.extractTimestampInt96(byteBuffer); + } + } + + private static class StringReader extends PrimitiveReader { + StringReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public UTF8String read(UTF8String ignored) { + Binary binary = column.nextBinary(); + ByteBuffer buffer = binary.toByteBuffer(); + if (buffer.hasArray()) { + return UTF8String.fromBytes( + buffer.array(), buffer.arrayOffset() + buffer.position(), buffer.remaining()); + } else { + return UTF8String.fromBytes(binary.getBytes()); + } + } + } + + private static class UUIDReader extends PrimitiveReader { + UUIDReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + @SuppressWarnings("ByteBufferBackingArray") + public UTF8String read(UTF8String ignored) { + return UTF8String.fromString(UUIDUtil.convert(column.nextBinary().toByteBuffer()).toString()); + } + } + + private static class ArrayReader extends RepeatedReader { + private int readPos = 0; + private int writePos = 0; + + ArrayReader(int definitionLevel, int repetitionLevel, ParquetValueReader reader) { + super(definitionLevel, repetitionLevel, reader); + } + + @Override + @SuppressWarnings("unchecked") + protected ReusableArrayData newListData(ArrayData reuse) { + this.readPos = 0; + this.writePos = 0; + + if (reuse instanceof ReusableArrayData) { + return (ReusableArrayData) reuse; + } else { + return new ReusableArrayData(); + } + } + + @Override + @SuppressWarnings("unchecked") + protected E getElement(ReusableArrayData list) { + E value = null; + if (readPos < list.capacity()) { + value = (E) list.values[readPos]; + } + + readPos += 1; + + return value; + } + + @Override + protected void addElement(ReusableArrayData reused, E element) { + if (writePos >= reused.capacity()) { + reused.grow(); + } + + reused.values[writePos] = element; + + writePos += 1; + } + + @Override + protected ArrayData buildList(ReusableArrayData list) { + list.setNumElements(writePos); + return list; + } + } + + private static class MapReader + extends RepeatedKeyValueReader { + private int readPos = 0; + private int writePos = 0; + + private final ReusableEntry entry = new ReusableEntry<>(); + private final ReusableEntry nullEntry = new ReusableEntry<>(); + + MapReader( + int definitionLevel, + int repetitionLevel, + ParquetValueReader keyReader, + ParquetValueReader valueReader) { + super(definitionLevel, repetitionLevel, keyReader, valueReader); + } + + @Override + @SuppressWarnings("unchecked") + protected ReusableMapData newMapData(MapData reuse) { + this.readPos = 0; + this.writePos = 0; + + if (reuse instanceof ReusableMapData) { + return (ReusableMapData) reuse; + } else { + return new ReusableMapData(); + } + } + + @Override + @SuppressWarnings("unchecked") + protected Map.Entry getPair(ReusableMapData map) { + Map.Entry kv = nullEntry; + if (readPos < map.capacity()) { + entry.set((K) map.keys.values[readPos], (V) map.values.values[readPos]); + kv = entry; + } + + readPos += 1; + + return kv; + } + + @Override + protected void addPair(ReusableMapData map, K key, V value) { + if (writePos >= map.capacity()) { + map.grow(); + } + + map.keys.values[writePos] = key; + map.values.values[writePos] = value; + + writePos += 1; + } + + @Override + protected MapData buildMap(ReusableMapData map) { + map.setNumElements(writePos); + return map; + } + } + + private static class InternalRowReader extends StructReader { + private final int numFields; + + InternalRowReader(List types, List> 
readers) { + super(types, readers); + this.numFields = readers.size(); + } + + @Override + protected GenericInternalRow newStructData(InternalRow reuse) { + if (reuse instanceof GenericInternalRow) { + return (GenericInternalRow) reuse; + } else { + return new GenericInternalRow(numFields); + } + } + + @Override + protected Object getField(GenericInternalRow intermediate, int pos) { + return intermediate.genericGet(pos); + } + + @Override + protected InternalRow buildStruct(GenericInternalRow struct) { + return struct; + } + + @Override + protected void set(GenericInternalRow row, int pos, Object value) { + row.update(pos, value); + } + + @Override + protected void setNull(GenericInternalRow row, int pos) { + row.setNullAt(pos); + } + + @Override + protected void setBoolean(GenericInternalRow row, int pos, boolean value) { + row.setBoolean(pos, value); + } + + @Override + protected void setInteger(GenericInternalRow row, int pos, int value) { + row.setInt(pos, value); + } + + @Override + protected void setLong(GenericInternalRow row, int pos, long value) { + row.setLong(pos, value); + } + + @Override + protected void setFloat(GenericInternalRow row, int pos, float value) { + row.setFloat(pos, value); + } + + @Override + protected void setDouble(GenericInternalRow row, int pos, double value) { + row.setDouble(pos, value); + } + } + + private static class ReusableMapData extends MapData { + private final ReusableArrayData keys; + private final ReusableArrayData values; + private int numElements; + + private ReusableMapData() { + this.keys = new ReusableArrayData(); + this.values = new ReusableArrayData(); + } + + private void grow() { + keys.grow(); + values.grow(); + } + + private int capacity() { + return keys.capacity(); + } + + public void setNumElements(int numElements) { + this.numElements = numElements; + keys.setNumElements(numElements); + values.setNumElements(numElements); + } + + @Override + public int numElements() { + return numElements; + } + + @Override + public MapData copy() { + return new ArrayBasedMapData(keyArray().copy(), valueArray().copy()); + } + + @Override + public ReusableArrayData keyArray() { + return keys; + } + + @Override + public ReusableArrayData valueArray() { + return values; + } + } + + private static class ReusableArrayData extends ArrayData { + private static final Object[] EMPTY = new Object[0]; + + private Object[] values = EMPTY; + private int numElements = 0; + + private void grow() { + if (values.length == 0) { + this.values = new Object[20]; + } else { + Object[] old = values; + this.values = new Object[old.length << 2]; + // copy the old array in case it has values that can be reused + System.arraycopy(old, 0, values, 0, old.length); + } + } + + private int capacity() { + return values.length; + } + + public void setNumElements(int numElements) { + this.numElements = numElements; + } + + @Override + public Object get(int ordinal, DataType dataType) { + return values[ordinal]; + } + + @Override + public int numElements() { + return numElements; + } + + @Override + public ArrayData copy() { + return new GenericArrayData(array()); + } + + @Override + public Object[] array() { + return Arrays.copyOfRange(values, 0, numElements); + } + + @Override + public void setNullAt(int i) { + values[i] = null; + } + + @Override + public void update(int ordinal, Object value) { + values[ordinal] = value; + } + + @Override + public boolean isNullAt(int ordinal) { + return null == values[ordinal]; + } + + @Override + public boolean getBoolean(int ordinal) { + 
return (boolean) values[ordinal]; + } + + @Override + public byte getByte(int ordinal) { + return (byte) values[ordinal]; + } + + @Override + public short getShort(int ordinal) { + return (short) values[ordinal]; + } + + @Override + public int getInt(int ordinal) { + return (int) values[ordinal]; + } + + @Override + public long getLong(int ordinal) { + return (long) values[ordinal]; + } + + @Override + public float getFloat(int ordinal) { + return (float) values[ordinal]; + } + + @Override + public double getDouble(int ordinal) { + return (double) values[ordinal]; + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return (Decimal) values[ordinal]; + } + + @Override + public UTF8String getUTF8String(int ordinal) { + return (UTF8String) values[ordinal]; + } + + @Override + public byte[] getBinary(int ordinal) { + return (byte[]) values[ordinal]; + } + + @Override + public CalendarInterval getInterval(int ordinal) { + return (CalendarInterval) values[ordinal]; + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + return (InternalRow) values[ordinal]; + } + + @Override + public ArrayData getArray(int ordinal) { + return (ArrayData) values[ordinal]; + } + + @Override + public MapData getMap(int ordinal) { + return (MapData) values[ordinal]; + } + + @Override + public VariantVal getVariant(int ordinal) { + throw new UnsupportedOperationException("Unsupported method: getVariant"); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java new file mode 100644 index 000000000000..678ebd218d71 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java @@ -0,0 +1,571 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
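ReusableArrayData above starts with an empty backing array, allocates 20 slots on first growth, and then quadruples (length << 2) while copying the old contents so previously read values can still be reused. A JDK-only sketch of that sizing policy (names are illustrative, not part of this patch):

import java.util.Arrays;

// Illustrative sketch of the growth policy used by ReusableArrayData above.
public class GrowSketch {
  private static Object[] grow(Object[] values) {
    if (values.length == 0) {
      return new Object[20];
    }
    // quadruple the capacity, preserving old contents
    return Arrays.copyOf(values, values.length << 2);
  }

  public static void main(String[] args) {
    Object[] values = new Object[0];
    for (int i = 0; i < 4; i++) {
      values = grow(values);
      System.out.println("capacity = " + values.length); // 20, 80, 320, 1280
    }
  }
}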
+ */ +package org.apache.iceberg.spark.data; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.UUID; +import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry; +import org.apache.iceberg.parquet.ParquetValueWriter; +import org.apache.iceberg.parquet.ParquetValueWriters; +import org.apache.iceberg.parquet.ParquetValueWriters.PrimitiveWriter; +import org.apache.iceberg.parquet.ParquetValueWriters.RepeatedKeyValueWriter; +import org.apache.iceberg.parquet.ParquetValueWriters.RepeatedWriter; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.util.DecimalUtil; +import org.apache.iceberg.util.UUIDUtil; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkParquetWriters { + private SparkParquetWriters() {} + + @SuppressWarnings("unchecked") + public static ParquetValueWriter buildWriter(StructType dfSchema, MessageType type) { + return (ParquetValueWriter) + ParquetWithSparkSchemaVisitor.visit(dfSchema, type, new WriteBuilder(type)); + } + + private static class WriteBuilder extends ParquetWithSparkSchemaVisitor> { + private final MessageType type; + + WriteBuilder(MessageType type) { + this.type = type; + } + + @Override + public ParquetValueWriter message( + StructType sStruct, MessageType message, List> fieldWriters) { + return struct(sStruct, message.asGroupType(), fieldWriters); + } + + @Override + public ParquetValueWriter struct( + StructType sStruct, GroupType struct, List> fieldWriters) { + List fields = struct.getFields(); + StructField[] sparkFields = sStruct.fields(); + List> writers = Lists.newArrayListWithExpectedSize(fieldWriters.size()); + List sparkTypes = Lists.newArrayList(); + for (int i = 0; i < fields.size(); i += 1) { + writers.add(newOption(struct.getType(i), fieldWriters.get(i))); + sparkTypes.add(sparkFields[i].dataType()); + } + return new InternalRowWriter(writers, sparkTypes); + } + + @Override + public ParquetValueWriter list( + ArrayType sArray, GroupType array, ParquetValueWriter elementWriter) { + GroupType repeated = array.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath); + int repeatedR = type.getMaxRepetitionLevel(repeatedPath); + + return new ArrayDataWriter<>( + repeatedD, + repeatedR, + 
newOption(repeated.getType(0), elementWriter), + sArray.elementType()); + } + + @Override + public ParquetValueWriter map( + MapType sMap, + GroupType map, + ParquetValueWriter keyWriter, + ParquetValueWriter valueWriter) { + GroupType repeatedKeyValue = map.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath); + int repeatedR = type.getMaxRepetitionLevel(repeatedPath); + + return new MapDataWriter<>( + repeatedD, + repeatedR, + newOption(repeatedKeyValue.getType(0), keyWriter), + newOption(repeatedKeyValue.getType(1), valueWriter), + sMap.keyType(), + sMap.valueType()); + } + + private ParquetValueWriter newOption(Type fieldType, ParquetValueWriter writer) { + int maxD = type.getMaxDefinitionLevel(path(fieldType.getName())); + return ParquetValueWriters.option(fieldType, maxD, writer); + } + + private static class LogicalTypeAnnotationParquetValueWriterVisitor + implements LogicalTypeAnnotation.LogicalTypeAnnotationVisitor> { + + private final ColumnDescriptor desc; + private final PrimitiveType primitive; + + LogicalTypeAnnotationParquetValueWriterVisitor( + ColumnDescriptor desc, PrimitiveType primitive) { + this.desc = desc; + this.primitive = primitive; + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.StringLogicalTypeAnnotation stringLogicalType) { + return Optional.of(utf8Strings(desc)); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.EnumLogicalTypeAnnotation enumLogicalType) { + return Optional.of(utf8Strings(desc)); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.JsonLogicalTypeAnnotation jsonLogicalType) { + return Optional.of(utf8Strings(desc)); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.UUIDLogicalTypeAnnotation uuidLogicalType) { + return Optional.of(uuids(desc)); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.MapLogicalTypeAnnotation mapLogicalType) { + return LogicalTypeAnnotation.LogicalTypeAnnotationVisitor.super.visit(mapLogicalType); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.ListLogicalTypeAnnotation listLogicalType) { + return LogicalTypeAnnotation.LogicalTypeAnnotationVisitor.super.visit(listLogicalType); + } + + @Override + public Optional> visit(DecimalLogicalTypeAnnotation decimal) { + switch (primitive.getPrimitiveTypeName()) { + case INT32: + return Optional.of(decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale())); + case INT64: + return Optional.of(decimalAsLong(desc, decimal.getPrecision(), decimal.getScale())); + case BINARY: + case FIXED_LEN_BYTE_ARRAY: + return Optional.of(decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale())); + } + return Optional.empty(); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.DateLogicalTypeAnnotation dateLogicalType) { + return Optional.of(ParquetValueWriters.ints(desc)); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.TimeLogicalTypeAnnotation timeLogicalType) { + if (timeLogicalType.getUnit() == LogicalTypeAnnotation.TimeUnit.MICROS) { + return Optional.of(ParquetValueWriters.longs(desc)); + } + return Optional.empty(); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.TimestampLogicalTypeAnnotation timestampLogicalType) { + if (timestampLogicalType.getUnit() == LogicalTypeAnnotation.TimeUnit.MICROS) { + return Optional.of(ParquetValueWriters.longs(desc)); + } + return Optional.empty(); + } + + @Override + public Optional> visit( + 
LogicalTypeAnnotation.IntLogicalTypeAnnotation intLogicalType) { + int bitWidth = intLogicalType.getBitWidth(); + if (bitWidth <= 8) { + return Optional.of(ParquetValueWriters.tinyints(desc)); + } else if (bitWidth <= 16) { + return Optional.of(ParquetValueWriters.shorts(desc)); + } else if (bitWidth <= 32) { + return Optional.of(ParquetValueWriters.ints(desc)); + } else { + return Optional.of(ParquetValueWriters.longs(desc)); + } + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.BsonLogicalTypeAnnotation bsonLogicalType) { + return Optional.of(byteArrays(desc)); + } + } + + @Override + public ParquetValueWriter primitive(DataType sType, PrimitiveType primitive) { + ColumnDescriptor desc = type.getColumnDescription(currentPath()); + LogicalTypeAnnotation logicalTypeAnnotation = primitive.getLogicalTypeAnnotation(); + + if (logicalTypeAnnotation != null) { + return logicalTypeAnnotation + .accept(new LogicalTypeAnnotationParquetValueWriterVisitor(desc, primitive)) + .orElseThrow( + () -> + new UnsupportedOperationException( + "Unsupported logical type: " + primitive.getLogicalTypeAnnotation())); + } + + switch (primitive.getPrimitiveTypeName()) { + case FIXED_LEN_BYTE_ARRAY: + case BINARY: + if (LogicalTypeAnnotation.uuidType().equals(primitive.getLogicalTypeAnnotation())) { + return uuids(desc); + } + return byteArrays(desc); + case BOOLEAN: + return ParquetValueWriters.booleans(desc); + case INT32: + return ints(sType, desc); + case INT64: + return ParquetValueWriters.longs(desc); + case FLOAT: + return ParquetValueWriters.floats(desc); + case DOUBLE: + return ParquetValueWriters.doubles(desc); + default: + throw new UnsupportedOperationException("Unsupported type: " + primitive); + } + } + } + + private static PrimitiveWriter ints(DataType type, ColumnDescriptor desc) { + if (type instanceof ByteType) { + return ParquetValueWriters.tinyints(desc); + } else if (type instanceof ShortType) { + return ParquetValueWriters.shorts(desc); + } + return ParquetValueWriters.ints(desc); + } + + private static PrimitiveWriter utf8Strings(ColumnDescriptor desc) { + return new UTF8StringWriter(desc); + } + + private static PrimitiveWriter uuids(ColumnDescriptor desc) { + return new UUIDWriter(desc); + } + + private static PrimitiveWriter decimalAsInteger( + ColumnDescriptor desc, int precision, int scale) { + return new IntegerDecimalWriter(desc, precision, scale); + } + + private static PrimitiveWriter decimalAsLong( + ColumnDescriptor desc, int precision, int scale) { + return new LongDecimalWriter(desc, precision, scale); + } + + private static PrimitiveWriter decimalAsFixed( + ColumnDescriptor desc, int precision, int scale) { + return new FixedDecimalWriter(desc, precision, scale); + } + + private static PrimitiveWriter byteArrays(ColumnDescriptor desc) { + return new ByteArrayWriter(desc); + } + + private static class UTF8StringWriter extends PrimitiveWriter { + private UTF8StringWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void write(int repetitionLevel, UTF8String value) { + column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(value.getBytes())); + } + } + + private static class IntegerDecimalWriter extends PrimitiveWriter { + private final int precision; + private final int scale; + + private IntegerDecimalWriter(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public void write(int repetitionLevel, Decimal decimal) { + Preconditions.checkArgument( + 
decimal.scale() == scale, + "Cannot write value as decimal(%s,%s), wrong scale: %s", + precision, + scale, + decimal); + Preconditions.checkArgument( + decimal.precision() <= precision, + "Cannot write value as decimal(%s,%s), too large: %s", + precision, + scale, + decimal); + + column.writeInteger(repetitionLevel, (int) decimal.toUnscaledLong()); + } + } + + private static class LongDecimalWriter extends PrimitiveWriter { + private final int precision; + private final int scale; + + private LongDecimalWriter(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public void write(int repetitionLevel, Decimal decimal) { + Preconditions.checkArgument( + decimal.scale() == scale, + "Cannot write value as decimal(%s,%s), wrong scale: %s", + precision, + scale, + decimal); + Preconditions.checkArgument( + decimal.precision() <= precision, + "Cannot write value as decimal(%s,%s), too large: %s", + precision, + scale, + decimal); + + column.writeLong(repetitionLevel, decimal.toUnscaledLong()); + } + } + + private static class FixedDecimalWriter extends PrimitiveWriter { + private final int precision; + private final int scale; + private final ThreadLocal bytes; + + private FixedDecimalWriter(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + this.bytes = + ThreadLocal.withInitial(() -> new byte[TypeUtil.decimalRequiredBytes(precision)]); + } + + @Override + public void write(int repetitionLevel, Decimal decimal) { + byte[] binary = + DecimalUtil.toReusedFixLengthBytes( + precision, scale, decimal.toJavaBigDecimal(), bytes.get()); + column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(binary)); + } + } + + private static class UUIDWriter extends PrimitiveWriter { + private static final ThreadLocal BUFFER = + ThreadLocal.withInitial( + () -> { + ByteBuffer buffer = ByteBuffer.allocate(16); + buffer.order(ByteOrder.BIG_ENDIAN); + return buffer; + }); + + private UUIDWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void write(int repetitionLevel, UTF8String string) { + UUID uuid = UUID.fromString(string.toString()); + ByteBuffer buffer = UUIDUtil.convertToByteBuffer(uuid, BUFFER.get()); + column.writeBinary(repetitionLevel, Binary.fromReusedByteBuffer(buffer)); + } + } + + private static class ByteArrayWriter extends PrimitiveWriter { + private ByteArrayWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void write(int repetitionLevel, byte[] bytes) { + column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(bytes)); + } + } + + private static class ArrayDataWriter extends RepeatedWriter { + private final DataType elementType; + + private ArrayDataWriter( + int definitionLevel, + int repetitionLevel, + ParquetValueWriter writer, + DataType elementType) { + super(definitionLevel, repetitionLevel, writer); + this.elementType = elementType; + } + + @Override + protected Iterator elements(ArrayData list) { + return new ElementIterator<>(list); + } + + private class ElementIterator implements Iterator { + private final int size; + private final ArrayData list; + private int index; + + private ElementIterator(ArrayData list) { + this.list = list; + size = list.numElements(); + index = 0; + } + + @Override + public boolean hasNext() { + return index != size; + } + + @Override + @SuppressWarnings("unchecked") + public E next() { + if (index >= size) { + throw new NoSuchElementException(); + } + 
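The IntegerDecimalWriter/LongDecimalWriter preconditions above enforce that an incoming value already carries the column's declared scale and fits its precision before its unscaled value is written. A JDK-only sketch of the same checks using java.math.BigDecimal (the helper name is illustrative, not part of this patch):

import java.math.BigDecimal;

// Illustrative sketch: a decimal(precision, scale) column stores the unscaled value,
// so the value must already have the declared scale and fit the declared precision.
public class DecimalUnscaledSketch {
  static long unscaled(BigDecimal value, int precision, int scale) {
    if (value.scale() != scale) {
      throw new IllegalArgumentException("wrong scale: " + value);
    }
    if (value.precision() > precision) {
      throw new IllegalArgumentException("too large: " + value);
    }
    return value.unscaledValue().longValueExact();
  }

  public static void main(String[] args) {
    System.out.println(unscaled(new BigDecimal("123.45"), 9, 2)); // 12345
    // unscaled(new BigDecimal("123.4"), 9, 2) would fail: its scale is 1, not 2
  }
}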
+ E element; + if (list.isNullAt(index)) { + element = null; + } else { + element = (E) list.get(index, elementType); + } + + index += 1; + + return element; + } + } + } + + private static class MapDataWriter extends RepeatedKeyValueWriter { + private final DataType keyType; + private final DataType valueType; + + private MapDataWriter( + int definitionLevel, + int repetitionLevel, + ParquetValueWriter keyWriter, + ParquetValueWriter valueWriter, + DataType keyType, + DataType valueType) { + super(definitionLevel, repetitionLevel, keyWriter, valueWriter); + this.keyType = keyType; + this.valueType = valueType; + } + + @Override + protected Iterator> pairs(MapData map) { + return new EntryIterator<>(map); + } + + private class EntryIterator implements Iterator> { + private final int size; + private final ArrayData keys; + private final ArrayData values; + private final ReusableEntry entry; + private int index; + + private EntryIterator(MapData map) { + size = map.numElements(); + keys = map.keyArray(); + values = map.valueArray(); + entry = new ReusableEntry<>(); + index = 0; + } + + @Override + public boolean hasNext() { + return index != size; + } + + @Override + @SuppressWarnings("unchecked") + public Map.Entry next() { + if (index >= size) { + throw new NoSuchElementException(); + } + + if (values.isNullAt(index)) { + entry.set((K) keys.get(index, keyType), null); + } else { + entry.set((K) keys.get(index, keyType), (V) values.get(index, valueType)); + } + + index += 1; + + return entry; + } + } + } + + private static class InternalRowWriter extends ParquetValueWriters.StructWriter { + private final DataType[] types; + + private InternalRowWriter(List> writers, List types) { + super(writers); + this.types = types.toArray(new DataType[0]); + } + + @Override + protected Object get(InternalRow struct, int index) { + return struct.get(index, types[index]); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkPlannedAvroReader.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkPlannedAvroReader.java new file mode 100644 index 000000000000..dc4af24685b3 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkPlannedAvroReader.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
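UUIDs are handled as 16 big-endian bytes throughout this patch; the ORC and Parquet UUID writers above both delegate the conversion to Iceberg's UUIDUtil. A JDK-only sketch of the equivalent round trip (the class name is illustrative, not part of this patch):

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.UUID;

// Illustrative sketch: a UUID laid out as 16 big-endian bytes (msb then lsb) and back.
public class UuidBytesSketch {
  public static void main(String[] args) {
    UUID uuid = UUID.fromString("f79c3e09-677c-4bbd-a479-3f349cb785e7");

    ByteBuffer buffer = ByteBuffer.allocate(16).order(ByteOrder.BIG_ENDIAN);
    buffer.putLong(uuid.getMostSignificantBits());
    buffer.putLong(uuid.getLeastSignificantBits());

    buffer.rewind();
    UUID roundTripped = new UUID(buffer.getLong(), buffer.getLong());
    System.out.println(uuid.equals(roundTripped)); // true
  }
}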
+ */ +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import org.apache.avro.LogicalType; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.Decoder; +import org.apache.iceberg.avro.AvroWithPartnerVisitor; +import org.apache.iceberg.avro.SupportsRowPosition; +import org.apache.iceberg.avro.ValueReader; +import org.apache.iceberg.avro.ValueReaders; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.catalyst.InternalRow; + +public class SparkPlannedAvroReader implements DatumReader, SupportsRowPosition { + + private final Types.StructType expectedType; + private final Map idToConstant; + private ValueReader reader; + + public static SparkPlannedAvroReader create(org.apache.iceberg.Schema schema) { + return create(schema, ImmutableMap.of()); + } + + public static SparkPlannedAvroReader create( + org.apache.iceberg.Schema schema, Map constants) { + return new SparkPlannedAvroReader(schema, constants); + } + + private SparkPlannedAvroReader( + org.apache.iceberg.Schema expectedSchema, Map constants) { + this.expectedType = expectedSchema.asStruct(); + this.idToConstant = constants; + } + + @Override + @SuppressWarnings("unchecked") + public void setSchema(Schema fileSchema) { + this.reader = + (ValueReader) + AvroWithPartnerVisitor.visit( + expectedType, + fileSchema, + new ReadBuilder(idToConstant), + AvroWithPartnerVisitor.FieldIDAccessors.get()); + } + + @Override + public InternalRow read(InternalRow reuse, Decoder decoder) throws IOException { + return reader.read(decoder, reuse); + } + + @Override + public void setRowPositionSupplier(Supplier posSupplier) { + if (reader instanceof SupportsRowPosition) { + ((SupportsRowPosition) reader).setRowPositionSupplier(posSupplier); + } + } + + private static class ReadBuilder extends AvroWithPartnerVisitor> { + private final Map idToConstant; + + private ReadBuilder(Map idToConstant) { + this.idToConstant = idToConstant; + } + + @Override + public ValueReader record(Type partner, Schema record, List> fieldReaders) { + if (partner == null) { + return ValueReaders.skipStruct(fieldReaders); + } + + Types.StructType expected = partner.asStructType(); + List>> readPlan = + ValueReaders.buildReadPlan(expected, record, fieldReaders, idToConstant); + + // TODO: should this pass expected so that struct.get can reuse containers? 
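SparkPlannedAvroReader above resolves the expected Iceberg schema against the Avro file schema and builds a read plan that pairs each projected position with the value reader, or a constant supplied via idToConstant, that populates it. A conceptual, JDK-only sketch of that idea (none of these names are Iceberg APIs; this is not the actual buildReadPlan implementation):

import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;

// Illustrative sketch of a "read plan": (target position, value source) pairs applied
// to fill an output row in the expected projection order.
public class ReadPlanSketch {
  record PlannedField(int targetPos, Supplier<Object> reader) {}

  static Object[] readRow(int width, List<PlannedField> plan) {
    Object[] row = new Object[width];
    for (PlannedField field : plan) {
      row[field.targetPos()] = field.reader().get();
    }
    return row;
  }

  public static void main(String[] args) {
    List<PlannedField> plan =
        List.of(
            new PlannedField(0, () -> 42L),        // value decoded from the file
            new PlannedField(1, () -> "constant"), // e.g. a partition constant
            new PlannedField(2, () -> null));      // field not present in the file
    System.out.println(Arrays.toString(readRow(3, plan))); // [42, constant, null]
  }
}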
+ return SparkValueReaders.struct(readPlan, expected.fields().size()); + } + + @Override + public ValueReader union(Type partner, Schema union, List> options) { + return ValueReaders.union(options); + } + + @Override + public ValueReader array(Type partner, Schema array, ValueReader elementReader) { + return SparkValueReaders.array(elementReader); + } + + @Override + public ValueReader arrayMap( + Type partner, Schema map, ValueReader keyReader, ValueReader valueReader) { + return SparkValueReaders.arrayMap(keyReader, valueReader); + } + + @Override + public ValueReader map(Type partner, Schema map, ValueReader valueReader) { + return SparkValueReaders.map(SparkValueReaders.strings(), valueReader); + } + + @Override + public ValueReader primitive(Type partner, Schema primitive) { + LogicalType logicalType = primitive.getLogicalType(); + if (logicalType != null) { + switch (logicalType.getName()) { + case "date": + // Spark uses the same representation + return ValueReaders.ints(); + + case "timestamp-millis": + // adjust to microseconds + ValueReader longs = ValueReaders.longs(); + return (ValueReader) (decoder, ignored) -> longs.read(decoder, null) * 1000L; + + case "timestamp-micros": + // Spark uses the same representation + return ValueReaders.longs(); + + case "decimal": + return SparkValueReaders.decimal( + ValueReaders.decimalBytesReader(primitive), + ((LogicalTypes.Decimal) logicalType).getScale()); + + case "uuid": + return SparkValueReaders.uuids(); + + default: + throw new IllegalArgumentException("Unknown logical type: " + logicalType); + } + } + + switch (primitive.getType()) { + case NULL: + return ValueReaders.nulls(); + case BOOLEAN: + return ValueReaders.booleans(); + case INT: + if (partner != null && partner.typeId() == Type.TypeID.LONG) { + return ValueReaders.intsAsLongs(); + } + return ValueReaders.ints(); + case LONG: + return ValueReaders.longs(); + case FLOAT: + if (partner != null && partner.typeId() == Type.TypeID.DOUBLE) { + return ValueReaders.floatsAsDoubles(); + } + return ValueReaders.floats(); + case DOUBLE: + return ValueReaders.doubles(); + case STRING: + return SparkValueReaders.strings(); + case FIXED: + return ValueReaders.fixed(primitive.getFixedSize()); + case BYTES: + return ValueReaders.bytes(); + case ENUM: + return SparkValueReaders.enums(primitive.getEnumSymbols()); + default: + throw new IllegalArgumentException("Unsupported type: " + primitive); + } + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java new file mode 100644 index 000000000000..7e65535f5ecb --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import org.apache.avro.io.Decoder; +import org.apache.avro.util.Utf8; +import org.apache.iceberg.avro.ValueReader; +import org.apache.iceberg.avro.ValueReaders; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.UUIDUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkValueReaders { + + private SparkValueReaders() {} + + static ValueReader strings() { + return StringReader.INSTANCE; + } + + static ValueReader enums(List symbols) { + return new EnumReader(symbols); + } + + static ValueReader uuids() { + return UUIDReader.INSTANCE; + } + + static ValueReader decimal(ValueReader unscaledReader, int scale) { + return new DecimalReader(unscaledReader, scale); + } + + static ValueReader array(ValueReader elementReader) { + return new ArrayReader(elementReader); + } + + static ValueReader arrayMap( + ValueReader keyReader, ValueReader valueReader) { + return new ArrayMapReader(keyReader, valueReader); + } + + static ValueReader map(ValueReader keyReader, ValueReader valueReader) { + return new MapReader(keyReader, valueReader); + } + + static ValueReader struct( + List>> readPlan, int numFields) { + return new PlannedStructReader(readPlan, numFields); + } + + static ValueReader struct( + List> readers, Types.StructType struct, Map idToConstant) { + return new StructReader(readers, struct, idToConstant); + } + + private static class StringReader implements ValueReader { + private static final StringReader INSTANCE = new StringReader(); + + private StringReader() {} + + @Override + public UTF8String read(Decoder decoder, Object reuse) throws IOException { + // use the decoder's readString(Utf8) method because it may be a resolving decoder + Utf8 utf8 = null; + if (reuse instanceof UTF8String) { + utf8 = new Utf8(((UTF8String) reuse).getBytes()); + } + + Utf8 string = decoder.readString(utf8); + return UTF8String.fromBytes(string.getBytes(), 0, string.getByteLength()); + } + } + + private static class EnumReader implements ValueReader { + private final UTF8String[] symbols; + + private EnumReader(List symbols) { + this.symbols = new UTF8String[symbols.size()]; + for (int i = 0; i < this.symbols.length; i += 1) { + this.symbols[i] = UTF8String.fromBytes(symbols.get(i).getBytes(StandardCharsets.UTF_8)); + } + } + + @Override + public UTF8String read(Decoder decoder, Object ignore) throws IOException { + int index = decoder.readEnum(); + return symbols[index]; + } + } + + private static class UUIDReader implements ValueReader { + private static final ThreadLocal BUFFER = + ThreadLocal.withInitial( + () -> { + ByteBuffer buffer = ByteBuffer.allocate(16); + buffer.order(ByteOrder.BIG_ENDIAN); + return buffer; + }); + + private static 
final UUIDReader INSTANCE = new UUIDReader();
+
+    private UUIDReader() {}
+
+    @Override
+    @SuppressWarnings("ByteBufferBackingArray")
+    public UTF8String read(Decoder decoder, Object reuse) throws IOException {
+      ByteBuffer buffer = BUFFER.get();
+      buffer.rewind();
+
+      decoder.readFixed(buffer.array(), 0, 16);
+
+      return UTF8String.fromString(UUIDUtil.convert(buffer).toString());
+    }
+  }
+
+  private static class DecimalReader implements ValueReader<Decimal> {
+    private final ValueReader<byte[]> bytesReader;
+    private final int scale;
+
+    private DecimalReader(ValueReader<byte[]> bytesReader, int scale) {
+      this.bytesReader = bytesReader;
+      this.scale = scale;
+    }
+
+    @Override
+    public Decimal read(Decoder decoder, Object reuse) throws IOException {
+      byte[] bytes = bytesReader.read(decoder, null);
+      return Decimal.apply(new BigDecimal(new BigInteger(bytes), scale));
+    }
+  }
+
+  private static class ArrayReader implements ValueReader<GenericArrayData> {
+    private final ValueReader<?> elementReader;
+    private final List<Object> reusedList = Lists.newArrayList();
+
+    private ArrayReader(ValueReader<?> elementReader) {
+      this.elementReader = elementReader;
+    }
+
+    @Override
+    public GenericArrayData read(Decoder decoder, Object reuse) throws IOException {
+      reusedList.clear();
+      long chunkLength = decoder.readArrayStart();
+
+      while (chunkLength > 0) {
+        for (int i = 0; i < chunkLength; i += 1) {
+          reusedList.add(elementReader.read(decoder, null));
+        }
+
+        chunkLength = decoder.arrayNext();
+      }
+
+      // this will convert the list to an array so it is okay to reuse the list
+      return new GenericArrayData(reusedList.toArray());
+    }
+  }
+
+  private static class ArrayMapReader implements ValueReader<ArrayBasedMapData> {
+    private final ValueReader<?> keyReader;
+    private final ValueReader<?> valueReader;
+
+    private final List<Object> reusedKeyList = Lists.newArrayList();
+    private final List<Object> reusedValueList = Lists.newArrayList();
+
+    private ArrayMapReader(ValueReader<?> keyReader, ValueReader<?> valueReader) {
+      this.keyReader = keyReader;
+      this.valueReader = valueReader;
+    }
+
+    @Override
+    public ArrayBasedMapData read(Decoder decoder, Object reuse) throws IOException {
+      reusedKeyList.clear();
+      reusedValueList.clear();
+
+      long chunkLength = decoder.readArrayStart();
+
+      while (chunkLength > 0) {
+        for (int i = 0; i < chunkLength; i += 1) {
+          reusedKeyList.add(keyReader.read(decoder, null));
+          reusedValueList.add(valueReader.read(decoder, null));
+        }
+
+        chunkLength = decoder.arrayNext();
+      }
+
+      return new ArrayBasedMapData(
+          new GenericArrayData(reusedKeyList.toArray()),
+          new GenericArrayData(reusedValueList.toArray()));
+    }
+  }
+
+  private static class MapReader implements ValueReader<ArrayBasedMapData> {
+    private final ValueReader<?> keyReader;
+    private final ValueReader<?> valueReader;
+
+    private final List<Object> reusedKeyList = Lists.newArrayList();
+    private final List<Object> reusedValueList = Lists.newArrayList();
+
+    private MapReader(ValueReader<?> keyReader, ValueReader<?> valueReader) {
+      this.keyReader = keyReader;
+      this.valueReader = valueReader;
+    }
+
+    @Override
+    public ArrayBasedMapData read(Decoder decoder, Object reuse) throws IOException {
+      reusedKeyList.clear();
+      reusedValueList.clear();
+
+      long chunkLength = decoder.readMapStart();
+
+      while (chunkLength > 0) {
+        for (int i = 0; i < chunkLength; i += 1) {
+          reusedKeyList.add(keyReader.read(decoder, null));
+          reusedValueList.add(valueReader.read(decoder, null));
+        }
+
+        chunkLength = decoder.mapNext();
+      }
+
+      return new ArrayBasedMapData(
+          new GenericArrayData(reusedKeyList.toArray()),
+          new GenericArrayData(reusedValueList.toArray()));
+    }
+  }
+
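+  // A minimal sketch (for illustration only; the decoder, field positions, and field readers here
+  // are hypothetical) of how the readers above compose: SparkValueReaders.struct(readPlan, numFields)
+  // wires planned field readers into the PlannedStructReader defined below, assuming an Avro Decoder
+  // positioned at a record with fields (id: long, name: string).
+  //
+  //   List<Pair<Integer, ValueReader<?>>> readPlan =
+  //       Lists.newArrayList(
+  //           Pair.of(0, (ValueReader<?>) ValueReaders.longs()),
+  //           Pair.of(1, (ValueReader<?>) SparkValueReaders.strings()));
+  //   ValueReader<InternalRow> rowReader = SparkValueReaders.struct(readPlan, 2);
+  //   InternalRow row = rowReader.read(decoder, null);
+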
+ static class PlannedStructReader extends ValueReaders.PlannedStructReader { + private final int numFields; + + protected PlannedStructReader(List>> readPlan, int numFields) { + super(readPlan); + this.numFields = numFields; + } + + @Override + protected InternalRow reuseOrCreate(Object reuse) { + if (reuse instanceof GenericInternalRow + && ((GenericInternalRow) reuse).numFields() == numFields) { + return (InternalRow) reuse; + } + return new GenericInternalRow(numFields); + } + + @Override + protected Object get(InternalRow struct, int pos) { + return null; + } + + @Override + protected void set(InternalRow struct, int pos, Object value) { + if (value != null) { + struct.update(pos, value); + } else { + struct.setNullAt(pos); + } + } + } + + static class StructReader extends ValueReaders.StructReader { + private final int numFields; + + protected StructReader( + List> readers, Types.StructType struct, Map idToConstant) { + super(readers, struct, idToConstant); + this.numFields = readers.size(); + } + + @Override + protected InternalRow reuseOrCreate(Object reuse) { + if (reuse instanceof GenericInternalRow + && ((GenericInternalRow) reuse).numFields() == numFields) { + return (InternalRow) reuse; + } + return new GenericInternalRow(numFields); + } + + @Override + protected Object get(InternalRow struct, int pos) { + return null; + } + + @Override + protected void set(InternalRow struct, int pos, Object value) { + if (value != null) { + struct.update(pos, value); + } else { + struct.setNullAt(pos); + } + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueWriters.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueWriters.java new file mode 100644 index 000000000000..bb8218bd83df --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueWriters.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.iceberg.spark.data;
+
+import java.io.IOException;
+import java.lang.reflect.Array;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.List;
+import java.util.UUID;
+import org.apache.avro.io.Encoder;
+import org.apache.avro.util.Utf8;
+import org.apache.iceberg.avro.ValueWriter;
+import org.apache.iceberg.types.TypeUtil;
+import org.apache.iceberg.util.DecimalUtil;
+import org.apache.iceberg.util.UUIDUtil;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.types.UTF8String;
+
+public class SparkValueWriters {
+
+  private SparkValueWriters() {}
+
+  static ValueWriter<UTF8String> strings() {
+    return StringWriter.INSTANCE;
+  }
+
+  static ValueWriter<UTF8String> uuids() {
+    return UUIDWriter.INSTANCE;
+  }
+
+  static ValueWriter<Decimal> decimal(int precision, int scale) {
+    return new DecimalWriter(precision, scale);
+  }
+
+  static <T> ValueWriter<ArrayData> array(ValueWriter<T> elementWriter, DataType elementType) {
+    return new ArrayWriter<>(elementWriter, elementType);
+  }
+
+  static <K, V> ValueWriter<MapData> arrayMap(
+      ValueWriter<K> keyWriter, DataType keyType, ValueWriter<V> valueWriter, DataType valueType) {
+    return new ArrayMapWriter<>(keyWriter, keyType, valueWriter, valueType);
+  }
+
+  static <K, V> ValueWriter<MapData> map(
+      ValueWriter<K> keyWriter, DataType keyType, ValueWriter<V> valueWriter, DataType valueType) {
+    return new MapWriter<>(keyWriter, keyType, valueWriter, valueType);
+  }
+
+  static ValueWriter<InternalRow> struct(List<ValueWriter<?>> writers, List<DataType> types) {
+    return new StructWriter(writers, types);
+  }
+
+  private static class StringWriter implements ValueWriter<UTF8String> {
+    private static final StringWriter INSTANCE = new StringWriter();
+
+    private StringWriter() {}
+
+    @Override
+    public void write(UTF8String s, Encoder encoder) throws IOException {
+      // use getBytes because it may return the backing byte array if available.
+ // otherwise, it copies to a new byte array, which is still cheaper than Avro + // calling toString, which incurs encoding costs + encoder.writeString(new Utf8(s.getBytes())); + } + } + + private static class UUIDWriter implements ValueWriter { + private static final ThreadLocal BUFFER = + ThreadLocal.withInitial( + () -> { + ByteBuffer buffer = ByteBuffer.allocate(16); + buffer.order(ByteOrder.BIG_ENDIAN); + return buffer; + }); + + private static final UUIDWriter INSTANCE = new UUIDWriter(); + + private UUIDWriter() {} + + @Override + @SuppressWarnings("ByteBufferBackingArray") + public void write(UTF8String s, Encoder encoder) throws IOException { + // TODO: direct conversion from string to byte buffer + UUID uuid = UUID.fromString(s.toString()); + // calling array() is safe because the buffer is always allocated by the thread-local + encoder.writeFixed(UUIDUtil.convertToByteBuffer(uuid, BUFFER.get()).array()); + } + } + + private static class DecimalWriter implements ValueWriter { + private final int precision; + private final int scale; + private final ThreadLocal bytes; + + private DecimalWriter(int precision, int scale) { + this.precision = precision; + this.scale = scale; + this.bytes = + ThreadLocal.withInitial(() -> new byte[TypeUtil.decimalRequiredBytes(precision)]); + } + + @Override + public void write(Decimal d, Encoder encoder) throws IOException { + encoder.writeFixed( + DecimalUtil.toReusedFixLengthBytes(precision, scale, d.toJavaBigDecimal(), bytes.get())); + } + } + + private static class ArrayWriter implements ValueWriter { + private final ValueWriter elementWriter; + private final DataType elementType; + + private ArrayWriter(ValueWriter elementWriter, DataType elementType) { + this.elementWriter = elementWriter; + this.elementType = elementType; + } + + @Override + @SuppressWarnings("unchecked") + public void write(ArrayData array, Encoder encoder) throws IOException { + encoder.writeArrayStart(); + int numElements = array.numElements(); + encoder.setItemCount(numElements); + for (int i = 0; i < numElements; i += 1) { + encoder.startItem(); + elementWriter.write((T) array.get(i, elementType), encoder); + } + encoder.writeArrayEnd(); + } + } + + private static class ArrayMapWriter implements ValueWriter { + private final ValueWriter keyWriter; + private final ValueWriter valueWriter; + private final DataType keyType; + private final DataType valueType; + + private ArrayMapWriter( + ValueWriter keyWriter, + DataType keyType, + ValueWriter valueWriter, + DataType valueType) { + this.keyWriter = keyWriter; + this.keyType = keyType; + this.valueWriter = valueWriter; + this.valueType = valueType; + } + + @Override + @SuppressWarnings("unchecked") + public void write(MapData map, Encoder encoder) throws IOException { + encoder.writeArrayStart(); + int numElements = map.numElements(); + encoder.setItemCount(numElements); + ArrayData keyArray = map.keyArray(); + ArrayData valueArray = map.valueArray(); + for (int i = 0; i < numElements; i += 1) { + encoder.startItem(); + keyWriter.write((K) keyArray.get(i, keyType), encoder); + valueWriter.write((V) valueArray.get(i, valueType), encoder); + } + encoder.writeArrayEnd(); + } + } + + private static class MapWriter implements ValueWriter { + private final ValueWriter keyWriter; + private final ValueWriter valueWriter; + private final DataType keyType; + private final DataType valueType; + + private MapWriter( + ValueWriter keyWriter, + DataType keyType, + ValueWriter valueWriter, + DataType valueType) { + this.keyWriter = 
keyWriter; + this.keyType = keyType; + this.valueWriter = valueWriter; + this.valueType = valueType; + } + + @Override + @SuppressWarnings("unchecked") + public void write(MapData map, Encoder encoder) throws IOException { + encoder.writeMapStart(); + int numElements = map.numElements(); + encoder.setItemCount(numElements); + ArrayData keyArray = map.keyArray(); + ArrayData valueArray = map.valueArray(); + for (int i = 0; i < numElements; i += 1) { + encoder.startItem(); + keyWriter.write((K) keyArray.get(i, keyType), encoder); + valueWriter.write((V) valueArray.get(i, valueType), encoder); + } + encoder.writeMapEnd(); + } + } + + static class StructWriter implements ValueWriter { + private final ValueWriter[] writers; + private final DataType[] types; + + @SuppressWarnings("unchecked") + private StructWriter(List> writers, List types) { + this.writers = (ValueWriter[]) Array.newInstance(ValueWriter.class, writers.size()); + this.types = new DataType[writers.size()]; + for (int i = 0; i < writers.size(); i += 1) { + this.writers[i] = writers.get(i); + this.types[i] = types.get(i); + } + } + + ValueWriter[] writers() { + return writers; + } + + @Override + public void write(InternalRow row, Encoder encoder) throws IOException { + for (int i = 0; i < types.length; i += 1) { + if (row.isNullAt(i)) { + writers[i].write(null, encoder); + } else { + write(row, i, writers[i], encoder); + } + } + } + + @SuppressWarnings("unchecked") + private void write(InternalRow row, int pos, ValueWriter writer, Encoder encoder) + throws IOException { + writer.write((T) row.get(pos, types[pos]), encoder); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessorFactory.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessorFactory.java new file mode 100644 index 000000000000..29e938bb092e --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessorFactory.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.vectorized; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.iceberg.arrow.vectorized.GenericArrowVectorAccessorFactory; +import org.apache.iceberg.util.UUIDUtil; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.unsafe.types.UTF8String; + +final class ArrowVectorAccessorFactory + extends GenericArrowVectorAccessorFactory< + Decimal, UTF8String, ColumnarArray, ArrowColumnVector> { + + ArrowVectorAccessorFactory() { + super( + DecimalFactoryImpl::new, + StringFactoryImpl::new, + StructChildFactoryImpl::new, + ArrayFactoryImpl::new); + } + + private static final class DecimalFactoryImpl implements DecimalFactory { + @Override + public Class getGenericClass() { + return Decimal.class; + } + + @Override + public Decimal ofLong(long value, int precision, int scale) { + return Decimal.apply(value, precision, scale); + } + + @Override + public Decimal ofBigDecimal(BigDecimal value, int precision, int scale) { + return Decimal.apply(value, precision, scale); + } + } + + private static final class StringFactoryImpl implements StringFactory { + @Override + public Class getGenericClass() { + return UTF8String.class; + } + + @Override + public UTF8String ofRow(VarCharVector vector, int rowId) { + int start = vector.getStartOffset(rowId); + int end = vector.getEndOffset(rowId); + + return UTF8String.fromAddress( + null, vector.getDataBuffer().memoryAddress() + start, end - start); + } + + @Override + public UTF8String ofRow(FixedSizeBinaryVector vector, int rowId) { + return UTF8String.fromString(UUIDUtil.convert(vector.get(rowId)).toString()); + } + + @Override + public UTF8String ofBytes(byte[] bytes) { + return UTF8String.fromBytes(bytes); + } + + @Override + public UTF8String ofByteBuffer(ByteBuffer byteBuffer) { + if (byteBuffer.hasArray()) { + return UTF8String.fromBytes( + byteBuffer.array(), + byteBuffer.arrayOffset() + byteBuffer.position(), + byteBuffer.remaining()); + } + byte[] bytes = new byte[byteBuffer.remaining()]; + byteBuffer.get(bytes); + return UTF8String.fromBytes(bytes); + } + } + + private static final class ArrayFactoryImpl + implements ArrayFactory { + @Override + public ArrowColumnVector ofChild(ValueVector childVector) { + return new ArrowColumnVector(childVector); + } + + @Override + public ColumnarArray ofRow(ValueVector vector, ArrowColumnVector childData, int rowId) { + ArrowBuf offsets = vector.getOffsetBuffer(); + int index = rowId * ListVector.OFFSET_WIDTH; + int start = offsets.getInt(index); + int end = offsets.getInt(index + ListVector.OFFSET_WIDTH); + return new ColumnarArray(childData, start, end - start); + } + } + + private static final class StructChildFactoryImpl + implements StructChildFactory { + @Override + public Class getGenericClass() { + return ArrowColumnVector.class; + } + + @Override + public ArrowColumnVector of(ValueVector childVector) { + return new ArrowColumnVector(childVector); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessors.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessors.java new file 
mode 100644 index 000000000000..4e02dafb3c13 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessors.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.arrow.vectorized.ArrowVectorAccessor; +import org.apache.iceberg.arrow.vectorized.VectorHolder; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.unsafe.types.UTF8String; + +public class ArrowVectorAccessors { + + private static final ArrowVectorAccessorFactory FACTORY = new ArrowVectorAccessorFactory(); + + static ArrowVectorAccessor + getVectorAccessor(VectorHolder holder) { + return FACTORY.getVectorAccessor(holder); + } + + private ArrowVectorAccessors() {} +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorBuilder.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorBuilder.java new file mode 100644 index 000000000000..cce30fd1c7f6 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorBuilder.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.arrow.vectorized.VectorHolder; +import org.apache.iceberg.arrow.vectorized.VectorHolder.ConstantVectorHolder; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.vectorized.ColumnVector; + +class ColumnVectorBuilder { + private boolean[] isDeleted; + private int[] rowIdMapping; + + public ColumnVectorBuilder withDeletedRows(int[] rowIdMappingArray, boolean[] isDeletedArray) { + this.rowIdMapping = rowIdMappingArray; + this.isDeleted = isDeletedArray; + return this; + } + + public ColumnVector build(VectorHolder holder, int numRows) { + if (holder.isDummy()) { + if (holder instanceof VectorHolder.DeletedVectorHolder) { + return new DeletedColumnVector(Types.BooleanType.get(), isDeleted); + } else if (holder instanceof ConstantVectorHolder) { + ConstantVectorHolder constantHolder = (ConstantVectorHolder) holder; + Type icebergType = constantHolder.icebergType(); + Object value = constantHolder.getConstant(); + return new ConstantColumnVector(icebergType, numRows, value); + } else { + throw new IllegalStateException("Unknown dummy vector holder: " + holder); + } + } else if (rowIdMapping != null) { + return new ColumnVectorWithFilter(holder, rowIdMapping); + } else { + return new IcebergArrowColumnVector(holder); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorWithFilter.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorWithFilter.java new file mode 100644 index 000000000000..ab0d652321d3 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorWithFilter.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.arrow.vectorized.VectorHolder; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.unsafe.types.UTF8String; + +public class ColumnVectorWithFilter extends IcebergArrowColumnVector { + private final int[] rowIdMapping; + + public ColumnVectorWithFilter(VectorHolder holder, int[] rowIdMapping) { + super(holder); + this.rowIdMapping = rowIdMapping; + } + + @Override + public boolean isNullAt(int rowId) { + return nullabilityHolder().isNullAt(rowIdMapping[rowId]) == 1; + } + + @Override + public boolean getBoolean(int rowId) { + return accessor().getBoolean(rowIdMapping[rowId]); + } + + @Override + public int getInt(int rowId) { + return accessor().getInt(rowIdMapping[rowId]); + } + + @Override + public long getLong(int rowId) { + return accessor().getLong(rowIdMapping[rowId]); + } + + @Override + public float getFloat(int rowId) { + return accessor().getFloat(rowIdMapping[rowId]); + } + + @Override + public double getDouble(int rowId) { + return accessor().getDouble(rowIdMapping[rowId]); + } + + @Override + public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor().getArray(rowIdMapping[rowId]); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) { + return null; + } + return accessor().getDecimal(rowIdMapping[rowId], precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor().getUTF8String(rowIdMapping[rowId]); + } + + @Override + public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor().getBinary(rowIdMapping[rowId]); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnarBatchReader.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnarBatchReader.java new file mode 100644 index 000000000000..f07d8c545e35 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnarBatchReader.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.vectorized; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.arrow.vectorized.BaseBatchReader; +import org.apache.iceberg.arrow.vectorized.VectorizedArrowReader; +import org.apache.iceberg.arrow.vectorized.VectorizedArrowReader.DeletedVectorReader; +import org.apache.iceberg.data.DeleteFilter; +import org.apache.iceberg.deletes.PositionDeleteIndex; +import org.apache.iceberg.parquet.VectorizedReader; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.Pair; +import org.apache.parquet.column.page.PageReadStore; +import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.hadoop.metadata.ColumnPath; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * {@link VectorizedReader} that returns Spark's {@link ColumnarBatch} to support Spark's vectorized + * read path. The {@link ColumnarBatch} returned is created by passing in the Arrow vectors + * populated via delegated read calls to {@linkplain VectorizedArrowReader VectorReader(s)}. + */ +public class ColumnarBatchReader extends BaseBatchReader { + private final boolean hasIsDeletedColumn; + private DeleteFilter deletes = null; + private long rowStartPosInBatch = 0; + + public ColumnarBatchReader(List> readers) { + super(readers); + this.hasIsDeletedColumn = + readers.stream().anyMatch(reader -> reader instanceof DeletedVectorReader); + } + + @Override + public void setRowGroupInfo( + PageReadStore pageStore, Map metaData, long rowPosition) { + super.setRowGroupInfo(pageStore, metaData, rowPosition); + this.rowStartPosInBatch = rowPosition; + } + + public void setDeleteFilter(DeleteFilter deleteFilter) { + this.deletes = deleteFilter; + } + + @Override + public final ColumnarBatch read(ColumnarBatch reuse, int numRowsToRead) { + if (reuse == null) { + closeVectors(); + } + + ColumnarBatch columnarBatch = new ColumnBatchLoader(numRowsToRead).loadDataToColumnBatch(); + rowStartPosInBatch += numRowsToRead; + return columnarBatch; + } + + private class ColumnBatchLoader { + private final int numRowsToRead; + // the rowId mapping to skip deleted rows for all column vectors inside a batch, it is null when + // there is no deletes + private int[] rowIdMapping; + // the array to indicate if a row is deleted or not, it is null when there is no "_deleted" + // metadata column + private boolean[] isDeleted; + + ColumnBatchLoader(int numRowsToRead) { + Preconditions.checkArgument( + numRowsToRead > 0, "Invalid number of rows to read: %s", numRowsToRead); + this.numRowsToRead = numRowsToRead; + if (hasIsDeletedColumn) { + isDeleted = new boolean[numRowsToRead]; + } + } + + ColumnarBatch loadDataToColumnBatch() { + int numRowsUndeleted = initRowIdMapping(); + + ColumnVector[] arrowColumnVectors = readDataToColumnVectors(); + + ColumnarBatch newColumnarBatch = new ColumnarBatch(arrowColumnVectors); + newColumnarBatch.setNumRows(numRowsUndeleted); + + if (hasEqDeletes()) { + applyEqDelete(newColumnarBatch); + } + + if (hasIsDeletedColumn && rowIdMapping != null) { + // reset the row id mapping array, so that it doesn't filter out the deleted rows + for (int i = 0; i < numRowsToRead; i++) { + rowIdMapping[i] = i; + } + newColumnarBatch.setNumRows(numRowsToRead); + } + + return newColumnarBatch; + } + + ColumnVector[] readDataToColumnVectors() { + ColumnVector[] 
arrowColumnVectors = new ColumnVector[readers.length]; + + ColumnVectorBuilder columnVectorBuilder = new ColumnVectorBuilder(); + for (int i = 0; i < readers.length; i += 1) { + vectorHolders[i] = readers[i].read(vectorHolders[i], numRowsToRead); + int numRowsInVector = vectorHolders[i].numValues(); + Preconditions.checkState( + numRowsInVector == numRowsToRead, + "Number of rows in the vector %s didn't match expected %s ", + numRowsInVector, + numRowsToRead); + + arrowColumnVectors[i] = + columnVectorBuilder + .withDeletedRows(rowIdMapping, isDeleted) + .build(vectorHolders[i], numRowsInVector); + } + return arrowColumnVectors; + } + + boolean hasEqDeletes() { + return deletes != null && deletes.hasEqDeletes(); + } + + int initRowIdMapping() { + Pair posDeleteRowIdMapping = posDelRowIdMapping(); + if (posDeleteRowIdMapping != null) { + rowIdMapping = posDeleteRowIdMapping.first(); + return posDeleteRowIdMapping.second(); + } else { + rowIdMapping = initEqDeleteRowIdMapping(); + return numRowsToRead; + } + } + + Pair posDelRowIdMapping() { + if (deletes != null && deletes.hasPosDeletes()) { + return buildPosDelRowIdMapping(deletes.deletedRowPositions()); + } else { + return null; + } + } + + /** + * Build a row id mapping inside a batch, which skips deleted rows. Here is an example of how we + * delete 2 rows in a batch with 8 rows in total. [0,1,2,3,4,5,6,7] -- Original status of the + * row id mapping array [F,F,F,F,F,F,F,F] -- Original status of the isDeleted array Position + * delete 2, 6 [0,1,3,4,5,7,-,-] -- After applying position deletes [Set Num records to 6] + * [F,F,T,F,F,F,T,F] -- After applying position deletes + * + * @param deletedRowPositions a set of deleted row positions + * @return the mapping array and the new num of rows in a batch, null if no row is deleted + */ + Pair buildPosDelRowIdMapping(PositionDeleteIndex deletedRowPositions) { + if (deletedRowPositions == null) { + return null; + } + + int[] posDelRowIdMapping = new int[numRowsToRead]; + int originalRowId = 0; + int currentRowId = 0; + while (originalRowId < numRowsToRead) { + if (!deletedRowPositions.isDeleted(originalRowId + rowStartPosInBatch)) { + posDelRowIdMapping[currentRowId] = originalRowId; + currentRowId++; + } else { + if (hasIsDeletedColumn) { + isDeleted[originalRowId] = true; + } + + deletes.incrementDeleteCount(); + } + originalRowId++; + } + + if (currentRowId == numRowsToRead) { + // there is no delete in this batch + return null; + } else { + return Pair.of(posDelRowIdMapping, currentRowId); + } + } + + int[] initEqDeleteRowIdMapping() { + int[] eqDeleteRowIdMapping = null; + if (hasEqDeletes()) { + eqDeleteRowIdMapping = new int[numRowsToRead]; + for (int i = 0; i < numRowsToRead; i++) { + eqDeleteRowIdMapping[i] = i; + } + } + + return eqDeleteRowIdMapping; + } + + /** + * Filter out the equality deleted rows. 
Here is an example, [0,1,2,3,4,5,6,7] -- Original + * status of the row id mapping array [F,F,F,F,F,F,F,F] -- Original status of the isDeleted + * array Position delete 2, 6 [0,1,3,4,5,7,-,-] -- After applying position deletes [Set Num + * records to 6] [F,F,T,F,F,F,T,F] -- After applying position deletes Equality delete 1 <= x <= + * 3 [0,4,5,7,-,-,-,-] -- After applying equality deletes [Set Num records to 4] + * [F,T,T,T,F,F,T,F] -- After applying equality deletes + * + * @param columnarBatch the {@link ColumnarBatch} to apply the equality delete + */ + void applyEqDelete(ColumnarBatch columnarBatch) { + Iterator it = columnarBatch.rowIterator(); + int rowId = 0; + int currentRowId = 0; + while (it.hasNext()) { + InternalRow row = it.next(); + if (deletes.eqDeletedRowFilter().test(row)) { + // the row is NOT deleted + // skip deleted rows by pointing to the next undeleted row Id + rowIdMapping[currentRowId] = rowIdMapping[rowId]; + currentRowId++; + } else { + if (hasIsDeletedColumn) { + isDeleted[rowIdMapping[rowId]] = true; + } + + deletes.incrementDeleteCount(); + } + + rowId++; + } + + columnarBatch.setNumRows(currentRowId); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java new file mode 100644 index 000000000000..1398a137c1c0 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +class ConstantColumnVector extends ColumnVector { + + private final Type icebergType; + private final Object constant; + private final int batchSize; + + ConstantColumnVector(Type icebergType, int batchSize, Object constant) { + // the type may be unknown for NULL vectors + super(icebergType != null ? 
SparkSchemaUtil.convert(icebergType) : null); + this.icebergType = icebergType; + this.constant = constant; + this.batchSize = batchSize; + } + + @Override + public void close() {} + + @Override + public boolean hasNull() { + return constant == null; + } + + @Override + public int numNulls() { + return constant == null ? batchSize : 0; + } + + @Override + public boolean isNullAt(int rowId) { + return constant == null; + } + + @Override + public boolean getBoolean(int rowId) { + return (boolean) constant; + } + + @Override + public byte getByte(int rowId) { + return (byte) constant; + } + + @Override + public short getShort(int rowId) { + return (short) constant; + } + + @Override + public int getInt(int rowId) { + return (int) constant; + } + + @Override + public long getLong(int rowId) { + return (long) constant; + } + + @Override + public float getFloat(int rowId) { + return (float) constant; + } + + @Override + public double getDouble(int rowId) { + return (double) constant; + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(this.getClass() + " does not implement getArray"); + } + + @Override + public ColumnarMap getMap(int ordinal) { + throw new UnsupportedOperationException(this.getClass() + " does not implement getMap"); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + return (Decimal) constant; + } + + @Override + public UTF8String getUTF8String(int rowId) { + return (UTF8String) constant; + } + + @Override + public byte[] getBinary(int rowId) { + return (byte[]) constant; + } + + @Override + public ColumnVector getChild(int ordinal) { + InternalRow constantAsRow = (InternalRow) constant; + Object childConstant = constantAsRow.get(ordinal, childType(ordinal)); + return new ConstantColumnVector(childIcebergType(ordinal), batchSize, childConstant); + } + + private Type childIcebergType(int ordinal) { + Types.StructType icebergTypeAsStruct = (Types.StructType) icebergType; + return icebergTypeAsStruct.fields().get(ordinal).type(); + } + + private DataType childType(int ordinal) { + StructType typeAsStruct = (StructType) type; + return typeAsStruct.fields()[ordinal].dataType(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/DeletedColumnVector.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/DeletedColumnVector.java new file mode 100644 index 000000000000..eec6ecb9ace4 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/DeletedColumnVector.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Type; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +public class DeletedColumnVector extends ColumnVector { + private final boolean[] isDeleted; + + public DeletedColumnVector(Type type, boolean[] isDeleted) { + super(SparkSchemaUtil.convert(type)); + Preconditions.checkArgument(isDeleted != null, "Boolean array isDeleted cannot be null"); + this.isDeleted = isDeleted; + } + + @Override + public void close() {} + + @Override + public boolean hasNull() { + return false; + } + + @Override + public int numNulls() { + return 0; + } + + @Override + public boolean isNullAt(int rowId) { + return false; + } + + @Override + public boolean getBoolean(int rowId) { + return isDeleted[rowId]; + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + @Override + public UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java new file mode 100644 index 000000000000..38ec3a0e838c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.arrow.vectorized.ArrowVectorAccessor; +import org.apache.iceberg.arrow.vectorized.NullabilityHolder; +import org.apache.iceberg.arrow.vectorized.VectorHolder; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Implementation of Spark's {@link ColumnVector} interface. The code for this class is heavily + * inspired from Spark's {@link ArrowColumnVector} The main difference is in how nullability checks + * are made in this class by relying on {@link NullabilityHolder} instead of the validity vector in + * the Arrow vector. + */ +public class IcebergArrowColumnVector extends ColumnVector { + + private final ArrowVectorAccessor accessor; + private final NullabilityHolder nullabilityHolder; + + public IcebergArrowColumnVector(VectorHolder holder) { + super(SparkSchemaUtil.convert(holder.icebergType())); + this.nullabilityHolder = holder.nullabilityHolder(); + this.accessor = ArrowVectorAccessors.getVectorAccessor(holder); + } + + protected ArrowVectorAccessor accessor() { + return accessor; + } + + protected NullabilityHolder nullabilityHolder() { + return nullabilityHolder; + } + + @Override + public void close() { + accessor.close(); + } + + @Override + public boolean hasNull() { + return nullabilityHolder.hasNulls(); + } + + @Override + public int numNulls() { + return nullabilityHolder.numNulls(); + } + + @Override + public boolean isNullAt(int rowId) { + return nullabilityHolder.isNullAt(rowId) == 1; + } + + @Override + public boolean getBoolean(int rowId) { + return accessor.getBoolean(rowId); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException("Unsupported type - byte"); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException("Unsupported type - short"); + } + + @Override + public int getInt(int rowId) { + return accessor.getInt(rowId); + } + + @Override + public long getLong(int rowId) { + return accessor.getLong(rowId); + } + + @Override + public float getFloat(int rowId) { + return accessor.getFloat(rowId); + } + + @Override + public double getDouble(int rowId) { + return accessor.getDouble(rowId); + } + + @Override + public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getArray(rowId); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException("Unsupported type - map"); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getDecimal(rowId, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getUTF8String(rowId); + } + + @Override + public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getBinary(rowId); + } + + @Override + public ArrowColumnVector getChild(int ordinal) { + return accessor.childColumn(ordinal); + } + + public ArrowVectorAccessor + vectorAccessor() { + return 
accessor; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/RowPositionColumnVector.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/RowPositionColumnVector.java new file mode 100644 index 000000000000..a389cd8286e5 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/RowPositionColumnVector.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +public class RowPositionColumnVector extends ColumnVector { + + private final long batchOffsetInFile; + + RowPositionColumnVector(long batchOffsetInFile) { + super(SparkSchemaUtil.convert(Types.LongType.get())); + this.batchOffsetInFile = batchOffsetInFile; + } + + @Override + public void close() {} + + @Override + public boolean hasNull() { + return false; + } + + @Override + public int numNulls() { + return 0; + } + + @Override + public boolean isNullAt(int rowId) { + return false; + } + + @Override + public boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + return batchOffsetInFile + rowId; + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + @Override + public UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java new file mode 100644 index 000000000000..c030311232a2 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java @@ -0,0 +1,497 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.vectorized; + +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.orc.OrcBatchReader; +import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor; +import org.apache.iceberg.orc.OrcValueReader; +import org.apache.iceberg.orc.OrcValueReaders; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.SparkOrcValueReaders; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.ql.exec.vector.ListColumnVector; +import org.apache.orc.storage.ql.exec.vector.MapColumnVector; +import org.apache.orc.storage.ql.exec.vector.StructColumnVector; +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +public class VectorizedSparkOrcReaders { + + private VectorizedSparkOrcReaders() {} + + public static OrcBatchReader buildReader( + Schema expectedSchema, TypeDescription fileSchema, Map idToConstant) { + Converter converter = + OrcSchemaWithTypeVisitor.visit(expectedSchema, fileSchema, new ReadBuilder(idToConstant)); + + return new OrcBatchReader() { + private long batchOffsetInFile; + + @Override + public ColumnarBatch read(VectorizedRowBatch batch) { + BaseOrcColumnVector cv = + (BaseOrcColumnVector) + converter.convert( + new StructColumnVector(batch.size, batch.cols), + batch.size, + batchOffsetInFile, + batch.selectedInUse, + batch.selected); + ColumnarBatch columnarBatch = + new ColumnarBatch( + IntStream.range(0, expectedSchema.columns().size()) + .mapToObj(cv::getChild) + .toArray(ColumnVector[]::new)); + columnarBatch.setNumRows(batch.size); + return columnarBatch; + } + + @Override + public void setBatchContext(long batchOffsetInFile) { + this.batchOffsetInFile = batchOffsetInFile; + } + }; + } + + private interface 
Converter { + ColumnVector convert( + org.apache.orc.storage.ql.exec.vector.ColumnVector columnVector, + int batchSize, + long batchOffsetInFile, + boolean isSelectedInUse, + int[] selected); + } + + private static class ReadBuilder extends OrcSchemaWithTypeVisitor { + private final Map idToConstant; + + private ReadBuilder(Map idToConstant) { + this.idToConstant = idToConstant; + } + + @Override + public Converter record( + Types.StructType iStruct, + TypeDescription record, + List names, + List fields) { + return new StructConverter(iStruct, fields, idToConstant); + } + + @Override + public Converter list(Types.ListType iList, TypeDescription array, Converter element) { + return new ArrayConverter(iList, element); + } + + @Override + public Converter map(Types.MapType iMap, TypeDescription map, Converter key, Converter value) { + return new MapConverter(iMap, key, value); + } + + @Override + public Converter primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) { + final OrcValueReader primitiveValueReader; + switch (primitive.getCategory()) { + case BOOLEAN: + primitiveValueReader = OrcValueReaders.booleans(); + break; + case BYTE: + // Iceberg does not have a byte type. Use int + case SHORT: + // Iceberg does not have a short type. Use int + case DATE: + case INT: + primitiveValueReader = OrcValueReaders.ints(); + break; + case LONG: + primitiveValueReader = OrcValueReaders.longs(); + break; + case FLOAT: + primitiveValueReader = OrcValueReaders.floats(); + break; + case DOUBLE: + primitiveValueReader = OrcValueReaders.doubles(); + break; + case TIMESTAMP_INSTANT: + case TIMESTAMP: + primitiveValueReader = SparkOrcValueReaders.timestampTzs(); + break; + case DECIMAL: + primitiveValueReader = + SparkOrcValueReaders.decimals(primitive.getPrecision(), primitive.getScale()); + break; + case CHAR: + case VARCHAR: + case STRING: + primitiveValueReader = SparkOrcValueReaders.utf8String(); + break; + case BINARY: + primitiveValueReader = + Type.TypeID.UUID == iPrimitive.typeId() + ? 
SparkOrcValueReaders.uuids() + : OrcValueReaders.bytes(); + break; + default: + throw new IllegalArgumentException("Unhandled type " + primitive); + } + return (columnVector, batchSize, batchOffsetInFile, isSelectedInUse, selected) -> + new PrimitiveOrcColumnVector( + iPrimitive, + batchSize, + columnVector, + primitiveValueReader, + batchOffsetInFile, + isSelectedInUse, + selected); + } + } + + private abstract static class BaseOrcColumnVector extends ColumnVector { + private final org.apache.orc.storage.ql.exec.vector.ColumnVector vector; + private final int batchSize; + private final boolean isSelectedInUse; + private final int[] selected; + private Integer numNulls; + + BaseOrcColumnVector( + Type type, + int batchSize, + org.apache.orc.storage.ql.exec.vector.ColumnVector vector, + boolean isSelectedInUse, + int[] selected) { + super(SparkSchemaUtil.convert(type)); + this.vector = vector; + this.batchSize = batchSize; + this.isSelectedInUse = isSelectedInUse; + this.selected = selected; + } + + @Override + public void close() {} + + @Override + public boolean hasNull() { + return !vector.noNulls; + } + + @Override + public int numNulls() { + if (numNulls == null) { + numNulls = numNullsHelper(); + } + return numNulls; + } + + private int numNullsHelper() { + if (vector.isRepeating) { + if (vector.isNull[0]) { + return batchSize; + } else { + return 0; + } + } else if (vector.noNulls) { + return 0; + } else { + int count = 0; + for (int i = 0; i < batchSize; i++) { + if (vector.isNull[i]) { + count++; + } + } + return count; + } + } + + protected int getRowIndex(int rowId) { + int row = isSelectedInUse ? selected[rowId] : rowId; + return vector.isRepeating ? 0 : row; + } + + @Override + public boolean isNullAt(int rowId) { + return vector.isNull[getRowIndex(rowId)]; + } + + @Override + public boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + @Override + public UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } + } + + private static class PrimitiveOrcColumnVector extends BaseOrcColumnVector { + private final org.apache.orc.storage.ql.exec.vector.ColumnVector vector; + private final OrcValueReader primitiveValueReader; + private final long batchOffsetInFile; + + PrimitiveOrcColumnVector( + Type type, + int batchSize, + org.apache.orc.storage.ql.exec.vector.ColumnVector vector, + OrcValueReader 
primitiveValueReader, + long batchOffsetInFile, + boolean isSelectedInUse, + int[] selected) { + super(type, batchSize, vector, isSelectedInUse, selected); + this.vector = vector; + this.primitiveValueReader = primitiveValueReader; + this.batchOffsetInFile = batchOffsetInFile; + } + + @Override + public boolean getBoolean(int rowId) { + return (Boolean) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public int getInt(int rowId) { + return (Integer) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public long getLong(int rowId) { + return (Long) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public float getFloat(int rowId) { + return (Float) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public double getDouble(int rowId) { + return (Double) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + // TODO: Is it okay to assume that (precision,scale) parameters == (precision,scale) of the + // decimal type + // and return a Decimal with (precision,scale) of the decimal type? + return (Decimal) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public UTF8String getUTF8String(int rowId) { + return (UTF8String) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public byte[] getBinary(int rowId) { + return (byte[]) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + } + + private static class ArrayConverter implements Converter { + private final Types.ListType listType; + private final Converter elementConverter; + + private ArrayConverter(Types.ListType listType, Converter elementConverter) { + this.listType = listType; + this.elementConverter = elementConverter; + } + + @Override + public ColumnVector convert( + org.apache.orc.storage.ql.exec.vector.ColumnVector vector, + int batchSize, + long batchOffsetInFile, + boolean isSelectedInUse, + int[] selected) { + ListColumnVector listVector = (ListColumnVector) vector; + ColumnVector elementVector = + elementConverter.convert(listVector.child, batchSize, batchOffsetInFile, false, null); + + return new BaseOrcColumnVector(listType, batchSize, vector, isSelectedInUse, selected) { + @Override + public ColumnarArray getArray(int rowId) { + int index = getRowIndex(rowId); + return new ColumnarArray( + elementVector, (int) listVector.offsets[index], (int) listVector.lengths[index]); + } + }; + } + } + + private static class MapConverter implements Converter { + private final Types.MapType mapType; + private final Converter keyConverter; + private final Converter valueConverter; + + private MapConverter(Types.MapType mapType, Converter keyConverter, Converter valueConverter) { + this.mapType = mapType; + this.keyConverter = keyConverter; + this.valueConverter = valueConverter; + } + + @Override + public ColumnVector convert( + org.apache.orc.storage.ql.exec.vector.ColumnVector vector, + int batchSize, + long batchOffsetInFile, + boolean isSelectedInUse, + int[] selected) { + MapColumnVector mapVector = (MapColumnVector) vector; + ColumnVector keyVector = + keyConverter.convert(mapVector.keys, batchSize, batchOffsetInFile, false, null); + ColumnVector valueVector = + valueConverter.convert(mapVector.values, batchSize, batchOffsetInFile, false, null); + + return new BaseOrcColumnVector(mapType, batchSize, vector, isSelectedInUse, selected) { + @Override + public ColumnarMap getMap(int rowId) { + int index 
= getRowIndex(rowId); + return new ColumnarMap( + keyVector, + valueVector, + (int) mapVector.offsets[index], + (int) mapVector.lengths[index]); + } + }; + } + } + + private static class StructConverter implements Converter { + private final Types.StructType structType; + private final List fieldConverters; + private final Map idToConstant; + + private StructConverter( + Types.StructType structType, + List fieldConverters, + Map idToConstant) { + this.structType = structType; + this.fieldConverters = fieldConverters; + this.idToConstant = idToConstant; + } + + @Override + public ColumnVector convert( + org.apache.orc.storage.ql.exec.vector.ColumnVector vector, + int batchSize, + long batchOffsetInFile, + boolean isSelectedInUse, + int[] selected) { + StructColumnVector structVector = (StructColumnVector) vector; + List fields = structType.fields(); + List fieldVectors = Lists.newArrayListWithExpectedSize(fields.size()); + for (int pos = 0, vectorIndex = 0; pos < fields.size(); pos += 1) { + Types.NestedField field = fields.get(pos); + if (idToConstant.containsKey(field.fieldId())) { + fieldVectors.add( + new ConstantColumnVector(field.type(), batchSize, idToConstant.get(field.fieldId()))); + } else if (field.equals(MetadataColumns.ROW_POSITION)) { + fieldVectors.add(new RowPositionColumnVector(batchOffsetInFile)); + } else if (field.equals(MetadataColumns.IS_DELETED)) { + fieldVectors.add(new ConstantColumnVector(field.type(), batchSize, false)); + } else { + fieldVectors.add( + fieldConverters + .get(vectorIndex) + .convert( + structVector.fields[vectorIndex], + batchSize, + batchOffsetInFile, + isSelectedInUse, + selected)); + vectorIndex++; + } + } + + return new BaseOrcColumnVector(structType, batchSize, vector, isSelectedInUse, selected) { + @Override + public ColumnVector getChild(int ordinal) { + return fieldVectors.get(ordinal); + } + }; + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java new file mode 100644 index 000000000000..e47152c79398 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
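Before moving on to the vectorized Parquet reader below, a worked micro-example of the row-index resolution used by the ORC column vectors above; the class name and sample values are invented for illustration only.

    // Mirrors BaseOrcColumnVector.getRowIndex: the selection vector is applied first,
    // then a repeating (single-value) vector collapses every access onto row 0.
    class RowIndexSketch {
      static int rowIndex(boolean isRepeating, boolean isSelectedInUse, int[] selected, int rowId) {
        int row = isSelectedInUse ? selected[rowId] : rowId;
        return isRepeating ? 0 : row;
      }

      public static void main(String[] args) {
        int[] selected = {2, 5, 7};                             // invented selection vector
        System.out.println(rowIndex(false, true, selected, 1)); // Spark row 1 reads ORC row 5
        System.out.println(rowIndex(true, true, selected, 1));  // repeating column: always ORC row 0
      }
    }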
+ */ +package org.apache.iceberg.spark.data.vectorized; + +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import org.apache.arrow.vector.NullCheckingForGet; +import org.apache.iceberg.Schema; +import org.apache.iceberg.arrow.vectorized.VectorizedReaderBuilder; +import org.apache.iceberg.data.DeleteFilter; +import org.apache.iceberg.parquet.TypeWithSchemaVisitor; +import org.apache.iceberg.parquet.VectorizedReader; +import org.apache.parquet.schema.MessageType; +import org.apache.spark.sql.catalyst.InternalRow; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class VectorizedSparkParquetReaders { + + private static final Logger LOG = LoggerFactory.getLogger(VectorizedSparkParquetReaders.class); + private static final String ENABLE_UNSAFE_MEMORY_ACCESS = "arrow.enable_unsafe_memory_access"; + private static final String ENABLE_UNSAFE_MEMORY_ACCESS_ENV = "ARROW_ENABLE_UNSAFE_MEMORY_ACCESS"; + private static final String ENABLE_NULL_CHECK_FOR_GET = "arrow.enable_null_check_for_get"; + private static final String ENABLE_NULL_CHECK_FOR_GET_ENV = "ARROW_ENABLE_NULL_CHECK_FOR_GET"; + + static { + try { + enableUnsafeMemoryAccess(); + disableNullCheckForGet(); + } catch (Exception e) { + LOG.warn("Couldn't set Arrow properties, which may impact read performance", e); + } + } + + private VectorizedSparkParquetReaders() {} + + public static ColumnarBatchReader buildReader( + Schema expectedSchema, + MessageType fileSchema, + Map idToConstant, + DeleteFilter deleteFilter) { + return (ColumnarBatchReader) + TypeWithSchemaVisitor.visit( + expectedSchema.asStruct(), + fileSchema, + new ReaderBuilder( + expectedSchema, + fileSchema, + NullCheckingForGet.NULL_CHECKING_ENABLED, + idToConstant, + ColumnarBatchReader::new, + deleteFilter)); + } + + // enables unsafe memory access to avoid costly checks to see if index is within bounds + // as long as it is not configured explicitly (see BoundsChecking in Arrow) + private static void enableUnsafeMemoryAccess() { + String value = confValue(ENABLE_UNSAFE_MEMORY_ACCESS, ENABLE_UNSAFE_MEMORY_ACCESS_ENV); + if (value == null) { + LOG.info("Enabling {}", ENABLE_UNSAFE_MEMORY_ACCESS); + System.setProperty(ENABLE_UNSAFE_MEMORY_ACCESS, "true"); + } else { + LOG.info("Unsafe memory access was configured explicitly: {}", value); + } + } + + // disables expensive null checks for every get call in favor of Iceberg nullability + // as long as it is not configured explicitly (see NullCheckingForGet in Arrow) + private static void disableNullCheckForGet() { + String value = confValue(ENABLE_NULL_CHECK_FOR_GET, ENABLE_NULL_CHECK_FOR_GET_ENV); + if (value == null) { + LOG.info("Disabling {}", ENABLE_NULL_CHECK_FOR_GET); + System.setProperty(ENABLE_NULL_CHECK_FOR_GET, "false"); + } else { + LOG.info("Null checking for get calls was configured explicitly: {}", value); + } + } + + private static String confValue(String propName, String envName) { + String propValue = System.getProperty(propName); + if (propValue != null) { + return propValue; + } + + return System.getenv(envName); + } + + private static class ReaderBuilder extends VectorizedReaderBuilder { + private final DeleteFilter deleteFilter; + + ReaderBuilder( + Schema expectedSchema, + MessageType parquetSchema, + boolean setArrowValidityVector, + Map idToConstant, + Function>, VectorizedReader> readerFactory, + DeleteFilter deleteFilter) { + super(expectedSchema, parquetSchema, setArrowValidityVector, idToConstant, readerFactory); + this.deleteFilter = deleteFilter; 
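The static initializer above only assigns the two Arrow properties when they are not already configured, so a deployment can pin them explicitly before this class is first loaded; a minimal sketch, with the property names taken from the constants above and everything else assumed. The ARROW_ENABLE_UNSAFE_MEMORY_ACCESS and ARROW_ENABLE_NULL_CHECK_FOR_GET environment variables are consulted by confValue as well.

    // Illustration only: keep Arrow's bounds checking and per-get null checks enabled,
    // which the reader would otherwise relax for performance.
    public class PinArrowChecks {
      public static void main(String[] args) {
        System.setProperty("arrow.enable_unsafe_memory_access", "false");
        System.setProperty("arrow.enable_null_check_for_get", "true");
        // ... create the Spark session and vectorized readers after this point ...
      }
    }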
+ } + + @Override + protected VectorizedReader vectorizedReader(List> reorderedFields) { + VectorizedReader reader = super.vectorizedReader(reorderedFields); + if (deleteFilter != null) { + ((ColumnarBatchReader) reader).setDeleteFilter(deleteFilter); + } + return reader; + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/BaseScalarFunction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/BaseScalarFunction.java new file mode 100644 index 000000000000..5ec44f314180 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/BaseScalarFunction.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; + +abstract class BaseScalarFunction implements ScalarFunction { + @Override + public int hashCode() { + return canonicalName().hashCode(); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (!(other instanceof ScalarFunction)) { + return false; + } + + ScalarFunction that = (ScalarFunction) other; + return canonicalName().equals(that.canonicalName()); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java new file mode 100644 index 000000000000..c3de3d48dbcc --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.functions; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Set; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.util.BucketUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampNTZType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A Spark function implementation for the Iceberg bucket transform. + * + *

Example usage: {@code SELECT system.bucket(128, 'abc')}, which returns the bucket 122. + * + *

Note that for performance reasons, the given input number of buckets is not validated in the + * implementations used in code-gen. The number of buckets must be positive to give meaningful + * results. + */ +public class BucketFunction implements UnboundFunction { + + private static final int NUM_BUCKETS_ORDINAL = 0; + private static final int VALUE_ORDINAL = 1; + + private static final Set SUPPORTED_NUM_BUCKETS_TYPES = + ImmutableSet.of(DataTypes.ByteType, DataTypes.ShortType, DataTypes.IntegerType); + + @Override + @SuppressWarnings("checkstyle:CyclomaticComplexity") + public BoundFunction bind(StructType inputType) { + if (inputType.size() != 2) { + throw new UnsupportedOperationException( + "Wrong number of inputs (expected numBuckets and value)"); + } + + StructField numBucketsField = inputType.fields()[NUM_BUCKETS_ORDINAL]; + StructField valueField = inputType.fields()[VALUE_ORDINAL]; + + if (!SUPPORTED_NUM_BUCKETS_TYPES.contains(numBucketsField.dataType())) { + throw new UnsupportedOperationException( + "Expected number of buckets to be tinyint, shortint or int"); + } + + DataType type = valueField.dataType(); + if (type instanceof DateType) { + return new BucketInt(type); + } else if (type instanceof ByteType + || type instanceof ShortType + || type instanceof IntegerType) { + return new BucketInt(DataTypes.IntegerType); + } else if (type instanceof LongType) { + return new BucketLong(type); + } else if (type instanceof TimestampType) { + return new BucketLong(type); + } else if (type instanceof TimestampNTZType) { + return new BucketLong(type); + } else if (type instanceof DecimalType) { + return new BucketDecimal(type); + } else if (type instanceof StringType) { + return new BucketString(); + } else if (type instanceof BinaryType) { + return new BucketBinary(); + } else { + throw new UnsupportedOperationException( + "Expected column to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary"); + } + } + + @Override + public String description() { + return name() + + "(numBuckets, col) - Call Iceberg's bucket transform\n" + + " numBuckets :: number of buckets to divide the rows into, e.g. bucket(100, 34) -> 79 (must be a tinyint, smallint, or int)\n" + + " col :: column to bucket (must be a date, integer, long, timestamp, decimal, string, or binary)"; + } + + @Override + public String name() { + return "bucket"; + } + + public abstract static class BucketBase extends BaseScalarFunction { + public static int apply(int numBuckets, int hashedValue) { + return (hashedValue & Integer.MAX_VALUE) % numBuckets; + } + + @Override + public String name() { + return "bucket"; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + } + + // Used for both int and date - tinyint and smallint are upcasted to int by Spark. 
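Before the individual bound implementations that follow, a short usage sketch; the string example reuses the arguments and result quoted in the Javadoc above, while the wrapper class and the integer example are assumptions for illustration.

    import org.apache.iceberg.spark.functions.BucketFunction;
    import org.apache.spark.unsafe.types.UTF8String;

    // Sketch: the static `invoke` methods are the same entry points Spark's codegen calls,
    // so they can be exercised directly, for example in a quick test.
    class BucketSketch {
      public static void main(String[] args) {
        // Same arguments as the Javadoc example: SELECT system.bucket(128, 'abc') -> 122
        Integer viaString = BucketFunction.BucketString.invoke(128, UTF8String.fromString("abc"));
        // Integers are hashed with BucketUtil and reduced with (hash & Integer.MAX_VALUE) % numBuckets
        int viaInt = BucketFunction.BucketInt.invoke(16, 34);
        System.out.println(viaString + " " + viaInt);
      }
    }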
+ public static class BucketInt extends BucketBase { + private final DataType sqlType; + + // magic method used in codegen + public static int invoke(int numBuckets, int value) { + return apply(numBuckets, hash(value)); + } + + // Visible for testing + public static int hash(int value) { + return BucketUtil.hash(value); + } + + public BucketInt(DataType sqlType) { + this.sqlType = sqlType; + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, sqlType}; + } + + @Override + public String canonicalName() { + return String.format("iceberg.bucket(%s)", sqlType.catalogString()); + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in the code-generated versions. + if (input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getInt(VALUE_ORDINAL)); + } + } + } + + // Used for both BigInt and Timestamp + public static class BucketLong extends BucketBase { + private final DataType sqlType; + + // magic function for usage with codegen - needs to be static + public static int invoke(int numBuckets, long value) { + return apply(numBuckets, hash(value)); + } + + // Visible for testing + public static int hash(long value) { + return BucketUtil.hash(value); + } + + public BucketLong(DataType sqlType) { + this.sqlType = sqlType; + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, sqlType}; + } + + @Override + public String canonicalName() { + return String.format("iceberg.bucket(%s)", sqlType.catalogString()); + } + + @Override + public Integer produceResult(InternalRow input) { + if (input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getLong(VALUE_ORDINAL)); + } + } + } + + public static class BucketString extends BucketBase { + // magic function for usage with codegen + public static Integer invoke(int numBuckets, UTF8String value) { + if (value == null) { + return null; + } + + // TODO - We can probably hash the bytes directly given they're already UTF-8 input. 
+ return apply(numBuckets, hash(value.toString())); + } + + // Visible for testing + public static int hash(String value) { + return BucketUtil.hash(value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.StringType}; + } + + @Override + public String canonicalName() { + return "iceberg.bucket(string)"; + } + + @Override + public Integer produceResult(InternalRow input) { + if (input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getUTF8String(VALUE_ORDINAL)); + } + } + } + + public static class BucketBinary extends BucketBase { + public static Integer invoke(int numBuckets, byte[] value) { + if (value == null) { + return null; + } + + return apply(numBuckets, hash(ByteBuffer.wrap(value))); + } + + // Visible for testing + public static int hash(ByteBuffer value) { + return BucketUtil.hash(value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.BinaryType}; + } + + @Override + public Integer produceResult(InternalRow input) { + if (input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getBinary(VALUE_ORDINAL)); + } + } + + @Override + public String canonicalName() { + return "iceberg.bucket(binary)"; + } + } + + public static class BucketDecimal extends BucketBase { + private final DataType sqlType; + private final int precision; + private final int scale; + + // magic method used in codegen + public static Integer invoke(int numBuckets, Decimal value) { + if (value == null) { + return null; + } + + return apply(numBuckets, hash(value.toJavaBigDecimal())); + } + + // Visible for testing + public static int hash(BigDecimal value) { + return BucketUtil.hash(value); + } + + public BucketDecimal(DataType sqlType) { + this.sqlType = sqlType; + this.precision = ((DecimalType) sqlType).precision(); + this.scale = ((DecimalType) sqlType).scale(); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, sqlType}; + } + + @Override + public Integer produceResult(InternalRow input) { + if (input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + int numBuckets = input.getInt(NUM_BUCKETS_ORDINAL); + Decimal value = input.getDecimal(VALUE_ORDINAL, precision, scale); + return invoke(numBuckets, value); + } + } + + @Override + public String canonicalName() { + return "iceberg.bucket(decimal)"; + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java new file mode 100644 index 000000000000..f52edd9b208f --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.TimestampNTZType; +import org.apache.spark.sql.types.TimestampType; + +/** + * A Spark function implementation for the Iceberg day transform. + * + *

Example usage: {@code SELECT system.days('source_col')}. + */ +public class DaysFunction extends UnaryUnboundFunction { + + @Override + protected BoundFunction doBind(DataType valueType) { + if (valueType instanceof DateType) { + return new DateToDaysFunction(); + } else if (valueType instanceof TimestampType) { + return new TimestampToDaysFunction(); + } else if (valueType instanceof TimestampNTZType) { + return new TimestampNtzToDaysFunction(); + } else { + throw new UnsupportedOperationException( + "Expected value to be date or timestamp: " + valueType.catalogString()); + } + } + + @Override + public String description() { + return name() + + "(col) - Call Iceberg's day transform\n" + + " col :: source column (must be date or timestamp)"; + } + + @Override + public String name() { + return "days"; + } + + private abstract static class BaseToDaysFunction extends BaseScalarFunction { + @Override + public String name() { + return "days"; + } + + @Override + public DataType resultType() { + return DataTypes.DateType; + } + } + + // Spark and Iceberg internal representations of dates match so no transformation is required + public static class DateToDaysFunction extends BaseToDaysFunction { + // magic method used in codegen + public static int invoke(int days) { + return days; + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.DateType}; + } + + @Override + public String canonicalName() { + return "iceberg.days(date)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : input.getInt(0); + } + } + + public static class TimestampToDaysFunction extends BaseToDaysFunction { + // magic method used in codegen + public static int invoke(long micros) { + return DateTimeUtil.microsToDays(micros); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.TimestampType}; + } + + @Override + public String canonicalName() { + return "iceberg.days(timestamp)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getLong(0)); + } + } + + public static class TimestampNtzToDaysFunction extends BaseToDaysFunction { + // magic method used in codegen + public static int invoke(long micros) { + return DateTimeUtil.microsToDays(micros); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.TimestampNTZType}; + } + + @Override + public String canonicalName() { + return "iceberg.days(timestamp_ntz)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getLong(0)); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java new file mode 100644 index 000000000000..660a182f0e78 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.TimestampNTZType; +import org.apache.spark.sql.types.TimestampType; + +/** + * A Spark function implementation for the Iceberg hour transform. + * + *

Example usage: {@code SELECT system.hours('source_col')}. + */ +public class HoursFunction extends UnaryUnboundFunction { + + @Override + protected BoundFunction doBind(DataType valueType) { + if (valueType instanceof TimestampType) { + return new TimestampToHoursFunction(); + } else if (valueType instanceof TimestampNTZType) { + return new TimestampNtzToHoursFunction(); + } else { + throw new UnsupportedOperationException( + "Expected value to be timestamp: " + valueType.catalogString()); + } + } + + @Override + public String description() { + return name() + + "(col) - Call Iceberg's hour transform\n" + + " col :: source column (must be timestamp)"; + } + + @Override + public String name() { + return "hours"; + } + + public static class TimestampToHoursFunction extends BaseScalarFunction { + // magic method used in codegen + public static int invoke(long micros) { + return DateTimeUtil.microsToHours(micros); + } + + @Override + public String name() { + return "hours"; + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.TimestampType}; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public String canonicalName() { + return "iceberg.hours(timestamp)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getLong(0)); + } + } + + public static class TimestampNtzToHoursFunction extends BaseScalarFunction { + // magic method used in codegen + public static int invoke(long micros) { + return DateTimeUtil.microsToHours(micros); + } + + @Override + public String name() { + return "hours"; + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.TimestampNTZType}; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public String canonicalName() { + return "iceberg.hours(timestamp_ntz)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getLong(0)); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java new file mode 100644 index 000000000000..689a0f4cb4df --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.functions; + +import org.apache.iceberg.IcebergBuild; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A function for use in SQL that returns the current Iceberg version, e.g. {@code SELECT + * system.iceberg_version()} will return a String such as "0.14.0" or "0.15.0-SNAPSHOT" + */ +public class IcebergVersionFunction implements UnboundFunction { + @Override + public BoundFunction bind(StructType inputType) { + if (inputType.fields().length > 0) { + throw new UnsupportedOperationException( + String.format("Cannot bind: %s does not accept arguments", name())); + } + + return new IcebergVersionFunctionImpl(); + } + + @Override + public String description() { + return name() + " - Returns the runtime Iceberg version"; + } + + @Override + public String name() { + return "iceberg_version"; + } + + // Implementing class cannot be private, otherwise Spark is unable to access the static invoke + // function during code-gen and calling the function fails + static class IcebergVersionFunctionImpl extends BaseScalarFunction { + private static final UTF8String VERSION = UTF8String.fromString(IcebergBuild.version()); + + // magic function used in code-gen. must be named `invoke`. + public static UTF8String invoke() { + return VERSION; + } + + @Override + public DataType[] inputTypes() { + return new DataType[0]; + } + + @Override + public DataType resultType() { + return DataTypes.StringType; + } + + @Override + public boolean isResultNullable() { + return false; + } + + @Override + public String canonicalName() { + return "iceberg." + name(); + } + + @Override + public String name() { + return "iceberg_version"; + } + + @Override + public UTF8String produceResult(InternalRow input) { + return invoke(); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java new file mode 100644 index 000000000000..353d850f86e2 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
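A usage sketch for the version function defined above: the SQL form and the sample return value follow the class Javadoc, while the helper class and the assumption that the session's current catalog is an Iceberg catalog are illustrative.

    import org.apache.spark.sql.SparkSession;

    class IcebergVersionSketch {
      // Assumes the current catalog is an Iceberg catalog that exposes the system namespace.
      static String icebergVersion(SparkSession spark) {
        return spark.sql("SELECT system.iceberg_version()").head().getString(0); // e.g. "0.14.0"
      }
    }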
+ */ +package org.apache.iceberg.spark.functions; + +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.TimestampNTZType; +import org.apache.spark.sql.types.TimestampType; + +/** + * A Spark function implementation for the Iceberg month transform. + * + *

Example usage: {@code SELECT system.months('source_col')}. + */ +public class MonthsFunction extends UnaryUnboundFunction { + + @Override + protected BoundFunction doBind(DataType valueType) { + if (valueType instanceof DateType) { + return new DateToMonthsFunction(); + } else if (valueType instanceof TimestampType) { + return new TimestampToMonthsFunction(); + } else if (valueType instanceof TimestampNTZType) { + return new TimestampNtzToMonthsFunction(); + } else { + throw new UnsupportedOperationException( + "Expected value to be date or timestamp: " + valueType.catalogString()); + } + } + + @Override + public String description() { + return name() + + "(col) - Call Iceberg's month transform\n" + + " col :: source column (must be date or timestamp)"; + } + + @Override + public String name() { + return "months"; + } + + private abstract static class BaseToMonthsFunction extends BaseScalarFunction { + @Override + public String name() { + return "months"; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + } + + public static class DateToMonthsFunction extends BaseToMonthsFunction { + // magic method used in codegen + public static int invoke(int days) { + return DateTimeUtil.daysToMonths(days); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.DateType}; + } + + @Override + public String canonicalName() { + return "iceberg.months(date)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getInt(0)); + } + } + + public static class TimestampToMonthsFunction extends BaseToMonthsFunction { + // magic method used in codegen + public static int invoke(long micros) { + return DateTimeUtil.microsToMonths(micros); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.TimestampType}; + } + + @Override + public String canonicalName() { + return "iceberg.months(timestamp)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getLong(0)); + } + } + + public static class TimestampNtzToMonthsFunction extends BaseToMonthsFunction { + // magic method used in codegen + public static int invoke(long micros) { + return DateTimeUtil.microsToMonths(micros); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.TimestampNTZType}; + } + + @Override + public String canonicalName() { + return "iceberg.months(timestamp_ntz)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getLong(0)); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java new file mode 100644 index 000000000000..6d9cadec576d --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; + +public class SparkFunctions { + + private SparkFunctions() {} + + private static final Map FUNCTIONS = + ImmutableMap.of( + "iceberg_version", new IcebergVersionFunction(), + "years", new YearsFunction(), + "months", new MonthsFunction(), + "days", new DaysFunction(), + "hours", new HoursFunction(), + "bucket", new BucketFunction(), + "truncate", new TruncateFunction()); + + private static final Map, UnboundFunction> CLASS_TO_FUNCTIONS = + ImmutableMap.of( + YearsFunction.class, new YearsFunction(), + MonthsFunction.class, new MonthsFunction(), + DaysFunction.class, new DaysFunction(), + HoursFunction.class, new HoursFunction(), + BucketFunction.class, new BucketFunction(), + TruncateFunction.class, new TruncateFunction()); + + private static final List FUNCTION_NAMES = ImmutableList.copyOf(FUNCTIONS.keySet()); + + // Functions that are added to all Iceberg catalogs should be accessed with the `system` + // namespace. They can also be accessed with no namespace at all if qualified with the + // catalog name, e.g. my_hadoop_catalog.iceberg_version(). + // As namespace resolution is handled by those rules in BaseCatalog, a list of names + // alone is returned. + public static List list() { + return FUNCTION_NAMES; + } + + public static UnboundFunction load(String name) { + // function resolution is case-insensitive to match the existing Spark behavior for functions + return FUNCTIONS.get(name.toLowerCase(Locale.ROOT)); + } + + public static UnboundFunction loadFunctionByClass(Class functionClass) { + Class declaringClass = functionClass.getDeclaringClass(); + if (declaringClass == null) { + return null; + } + + return CLASS_TO_FUNCTIONS.get(declaringClass); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java new file mode 100644 index 000000000000..fac90c9efee6 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.util.Set; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.util.BinaryUtil; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.TruncateUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A Spark function implementation for the Iceberg truncate transform. + * + *

Example usage: {@code SELECT system.truncate(1, 'abc')}, which returns the String 'a'. + * + *

Note that for performance reasons, the given input width is not validated in the + * implementations used in code-gen. The width must remain non-negative to give meaningful results. + */ +public class TruncateFunction implements UnboundFunction { + + private static final int WIDTH_ORDINAL = 0; + private static final int VALUE_ORDINAL = 1; + + private static final Set SUPPORTED_WIDTH_TYPES = + ImmutableSet.of(DataTypes.ByteType, DataTypes.ShortType, DataTypes.IntegerType); + + @Override + public BoundFunction bind(StructType inputType) { + if (inputType.size() != 2) { + throw new UnsupportedOperationException("Wrong number of inputs (expected width and value)"); + } + + StructField widthField = inputType.fields()[WIDTH_ORDINAL]; + StructField valueField = inputType.fields()[VALUE_ORDINAL]; + + if (!SUPPORTED_WIDTH_TYPES.contains(widthField.dataType())) { + throw new UnsupportedOperationException( + "Expected truncation width to be tinyint, shortint or int"); + } + + DataType valueType = valueField.dataType(); + if (valueType instanceof ByteType) { + return new TruncateTinyInt(); + } else if (valueType instanceof ShortType) { + return new TruncateSmallInt(); + } else if (valueType instanceof IntegerType) { + return new TruncateInt(); + } else if (valueType instanceof LongType) { + return new TruncateBigInt(); + } else if (valueType instanceof DecimalType) { + return new TruncateDecimal( + ((DecimalType) valueType).precision(), ((DecimalType) valueType).scale()); + } else if (valueType instanceof StringType) { + return new TruncateString(); + } else if (valueType instanceof BinaryType) { + return new TruncateBinary(); + } else { + throw new UnsupportedOperationException( + "Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary"); + } + } + + @Override + public String description() { + return name() + + "(width, col) - Call Iceberg's truncate transform\n" + + " width :: width for truncation, e.g. 
truncate(10, 255) -> 250 (must be an integer)\n" + + " col :: column to truncate (must be an integer, decimal, string, or binary)"; + } + + @Override + public String name() { + return "truncate"; + } + + public abstract static class TruncateBase extends BaseScalarFunction { + @Override + public String name() { + return "truncate"; + } + } + + public static class TruncateTinyInt extends TruncateBase { + public static byte invoke(int width, byte value) { + return TruncateUtil.truncateByte(width, value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.ByteType}; + } + + @Override + public DataType resultType() { + return DataTypes.ByteType; + } + + @Override + public String canonicalName() { + return "iceberg.truncate(tinyint)"; + } + + @Override + public Byte produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(WIDTH_ORDINAL), input.getByte(VALUE_ORDINAL)); + } + } + } + + public static class TruncateSmallInt extends TruncateBase { + // magic method used in codegen + public static short invoke(int width, short value) { + return TruncateUtil.truncateShort(width, value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.ShortType}; + } + + @Override + public DataType resultType() { + return DataTypes.ShortType; + } + + @Override + public String canonicalName() { + return "iceberg.truncate(smallint)"; + } + + @Override + public Short produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(WIDTH_ORDINAL), input.getShort(VALUE_ORDINAL)); + } + } + } + + public static class TruncateInt extends TruncateBase { + // magic method used in codegen + public static int invoke(int width, int value) { + return TruncateUtil.truncateInt(width, value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.IntegerType}; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public String canonicalName() { + return "iceberg.truncate(int)"; + } + + @Override + public Integer produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(WIDTH_ORDINAL), input.getInt(VALUE_ORDINAL)); + } + } + } + + public static class TruncateBigInt extends TruncateBase { + // magic function for usage with codegen + public static long invoke(int width, long value) { + return TruncateUtil.truncateLong(width, value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.LongType}; + } + + @Override + public DataType resultType() { + return DataTypes.LongType; + } + + @Override + public String canonicalName() { + return "iceberg.truncate(bigint)"; + } + + @Override + public Long produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(WIDTH_ORDINAL), input.getLong(VALUE_ORDINAL)); + } + } + } + + public static class TruncateString extends TruncateBase { + // magic function for usage with codegen + public static UTF8String invoke(int width, UTF8String value) { + if (value == null) { + return null; + } + + return value.substring(0, width); + } 
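A short sketch tying the static entry points of this class together; the quoted results come from the class Javadoc and from description() above, and the wrapper class is an assumption for illustration.

    import org.apache.iceberg.spark.functions.TruncateFunction;
    import org.apache.spark.unsafe.types.UTF8String;

    // Sketch: the same static `invoke` methods Spark's generated code calls.
    class TruncateSketch {
      public static void main(String[] args) {
        int i = TruncateFunction.TruncateInt.invoke(10, 255);                                   // 250, per description()
        UTF8String s = TruncateFunction.TruncateString.invoke(1, UTF8String.fromString("abc")); // "a", per the Javadoc
        System.out.println(i + " " + s);
      }
    }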
+ + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.StringType}; + } + + @Override + public DataType resultType() { + return DataTypes.StringType; + } + + @Override + public String canonicalName() { + return "iceberg.truncate(string)"; + } + + @Override + public UTF8String produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(WIDTH_ORDINAL), input.getUTF8String(VALUE_ORDINAL)); + } + } + } + + public static class TruncateBinary extends TruncateBase { + // magic method used in codegen + public static byte[] invoke(int width, byte[] value) { + if (value == null) { + return null; + } + + return ByteBuffers.toByteArray( + BinaryUtil.truncateBinaryUnsafe(ByteBuffer.wrap(value), width)); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.BinaryType}; + } + + @Override + public DataType resultType() { + return DataTypes.BinaryType; + } + + @Override + public String canonicalName() { + return "iceberg.truncate(binary)"; + } + + @Override + public byte[] produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(WIDTH_ORDINAL), input.getBinary(VALUE_ORDINAL)); + } + } + } + + public static class TruncateDecimal extends TruncateBase { + private final int precision; + private final int scale; + + public TruncateDecimal(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + + // magic method used in codegen + public static Decimal invoke(int width, Decimal value) { + if (value == null) { + return null; + } + + return Decimal.apply( + TruncateUtil.truncateDecimal(BigInteger.valueOf(width), value.toJavaBigDecimal())); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.createDecimalType(precision, scale)}; + } + + @Override + public DataType resultType() { + return DataTypes.createDecimalType(precision, scale); + } + + @Override + public String canonicalName() { + return String.format("iceberg.truncate(decimal(%d,%d))", precision, scale); + } + + @Override + public Decimal produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + int width = input.getInt(WIDTH_ORDINAL); + Decimal value = input.getDecimal(VALUE_ORDINAL, precision, scale); + return invoke(width, value); + } + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/UnaryUnboundFunction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/UnaryUnboundFunction.java new file mode 100644 index 000000000000..9003c68919dc --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/UnaryUnboundFunction.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructType; + +/** An unbound function that accepts only one argument */ +abstract class UnaryUnboundFunction implements UnboundFunction { + + @Override + public BoundFunction bind(StructType inputType) { + DataType valueType = valueType(inputType); + return doBind(valueType); + } + + protected abstract BoundFunction doBind(DataType valueType); + + private DataType valueType(StructType inputType) { + if (inputType.size() != 1) { + throw new UnsupportedOperationException("Wrong number of inputs (expected value)"); + } + + return inputType.fields()[0].dataType(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java new file mode 100644 index 000000000000..cfd1b0e8d002 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.TimestampNTZType; +import org.apache.spark.sql.types.TimestampType; + +/** + * A Spark function implementation for the Iceberg year transform. + * + *

Example usage: {@code SELECT system.years('source_col')}. + */ +public class YearsFunction extends UnaryUnboundFunction { + + @Override + protected BoundFunction doBind(DataType valueType) { + if (valueType instanceof DateType) { + return new DateToYearsFunction(); + } else if (valueType instanceof TimestampType) { + return new TimestampToYearsFunction(); + } else if (valueType instanceof TimestampNTZType) { + return new TimestampNtzToYearsFunction(); + } else { + throw new UnsupportedOperationException( + "Expected value to be date or timestamp: " + valueType.catalogString()); + } + } + + @Override + public String description() { + return name() + + "(col) - Call Iceberg's year transform\n" + + " col :: source column (must be date or timestamp)"; + } + + @Override + public String name() { + return "years"; + } + + private abstract static class BaseToYearsFunction extends BaseScalarFunction { + @Override + public String name() { + return "years"; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + } + + public static class DateToYearsFunction extends BaseToYearsFunction { + // magic method used in codegen + public static int invoke(int days) { + return DateTimeUtil.daysToYears(days); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.DateType}; + } + + @Override + public String canonicalName() { + return "iceberg.years(date)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getInt(0)); + } + } + + public static class TimestampToYearsFunction extends BaseToYearsFunction { + // magic method used in codegen + public static int invoke(long micros) { + return DateTimeUtil.microsToYears(micros); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.TimestampType}; + } + + @Override + public String canonicalName() { + return "iceberg.years(timestamp)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getLong(0)); + } + } + + public static class TimestampNtzToYearsFunction extends BaseToYearsFunction { + // magic method used in codegen + public static int invoke(long micros) { + return DateTimeUtil.microsToYears(micros); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.TimestampNTZType}; + } + + @Override + public String canonicalName() { + return "iceberg.years(timestamp_ntz)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getLong(0)); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java new file mode 100644 index 000000000000..40a343b55b80 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
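Taken together, the years, months, days, and hours functions above all reduce a date or timestamp to an integer ordinal; a small sketch of their static entry points, with the epoch input chosen for illustration and the wrapper class assumed. SparkFunctions.load (shown earlier in this diff) resolves the same functions by name, case-insensitively.

    import org.apache.iceberg.spark.functions.DaysFunction;
    import org.apache.iceberg.spark.functions.HoursFunction;
    import org.apache.iceberg.spark.functions.MonthsFunction;
    import org.apache.iceberg.spark.functions.YearsFunction;

    // Sketch: each transform's static `invoke` mirrors the SQL form in the Javadocs,
    // e.g. SELECT system.years('source_col').
    class TimeTransformSketch {
      public static void main(String[] args) {
        long epochMicros = 0L; // 1970-01-01T00:00:00
        System.out.println(YearsFunction.TimestampToYearsFunction.invoke(epochMicros));   // 0
        System.out.println(MonthsFunction.TimestampToMonthsFunction.invoke(epochMicros)); // 0
        System.out.println(DaysFunction.TimestampToDaysFunction.invoke(epochMicros));     // 0
        System.out.println(HoursFunction.TimestampToHoursFunction.invoke(epochMicros));   // 0
      }
    }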
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.mapping.MappingUtil; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkTableUtil.SparkPartition; +import org.apache.iceberg.util.LocationUtil; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +class AddFilesProcedure extends BaseProcedure { + + private static final ProcedureParameter TABLE_PARAM = + ProcedureParameter.required("table", DataTypes.StringType); + private static final ProcedureParameter SOURCE_TABLE_PARAM = + ProcedureParameter.required("source_table", DataTypes.StringType); + private static final ProcedureParameter PARTITION_FILTER_PARAM = + ProcedureParameter.optional("partition_filter", STRING_MAP); + private static final ProcedureParameter CHECK_DUPLICATE_FILES_PARAM = + ProcedureParameter.optional("check_duplicate_files", DataTypes.BooleanType); + + private static final ProcedureParameter PARALLELISM = + ProcedureParameter.optional("parallelism", DataTypes.IntegerType); + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + TABLE_PARAM, + SOURCE_TABLE_PARAM, + PARTITION_FILTER_PARAM, + CHECK_DUPLICATE_FILES_PARAM, + PARALLELISM + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("added_files_count", DataTypes.LongType, false, Metadata.empty()), + new StructField("changed_partition_count", DataTypes.LongType, true, Metadata.empty()), + }); + + private AddFilesProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static 
SparkProcedures.ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected AddFilesProcedure doBuild() { + return new AddFilesProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + ProcedureInput input = new ProcedureInput(spark(), tableCatalog(), PARAMETERS, args); + + Identifier tableIdent = input.ident(TABLE_PARAM); + + CatalogPlugin sessionCat = spark().sessionState().catalogManager().v2SessionCatalog(); + Identifier sourceIdent = input.ident(SOURCE_TABLE_PARAM, sessionCat); + + Map partitionFilter = + input.asStringMap(PARTITION_FILTER_PARAM, ImmutableMap.of()); + + boolean checkDuplicateFiles = input.asBoolean(CHECK_DUPLICATE_FILES_PARAM, true); + + int parallelism = input.asInt(PARALLELISM, 1); + + return importToIceberg( + tableIdent, sourceIdent, partitionFilter, checkDuplicateFiles, parallelism); + } + + private InternalRow[] toOutputRows(Snapshot snapshot) { + Map summary = snapshot.summary(); + return new InternalRow[] { + newInternalRow(addedFilesCount(summary), changedPartitionCount(summary)) + }; + } + + private long addedFilesCount(Map stats) { + return PropertyUtil.propertyAsLong(stats, SnapshotSummary.ADDED_FILES_PROP, 0L); + } + + private Long changedPartitionCount(Map stats) { + return PropertyUtil.propertyAsNullableLong(stats, SnapshotSummary.CHANGED_PARTITION_COUNT_PROP); + } + + private boolean isFileIdentifier(Identifier ident) { + String[] namespace = ident.namespace(); + return namespace.length == 1 + && (namespace[0].equalsIgnoreCase("orc") + || namespace[0].equalsIgnoreCase("parquet") + || namespace[0].equalsIgnoreCase("avro")); + } + + private InternalRow[] importToIceberg( + Identifier destIdent, + Identifier sourceIdent, + Map partitionFilter, + boolean checkDuplicateFiles, + int parallelism) { + return modifyIcebergTable( + destIdent, + table -> { + validatePartitionSpec(table, partitionFilter); + ensureNameMappingPresent(table); + + if (isFileIdentifier(sourceIdent)) { + Path sourcePath = new Path(sourceIdent.name()); + String format = sourceIdent.namespace()[0]; + importFileTable( + table, + sourcePath, + format, + partitionFilter, + checkDuplicateFiles, + table.spec(), + parallelism); + } else { + importCatalogTable( + table, sourceIdent, partitionFilter, checkDuplicateFiles, parallelism); + } + + Snapshot snapshot = table.currentSnapshot(); + return toOutputRows(snapshot); + }); + } + + private static void ensureNameMappingPresent(Table table) { + if (table.properties().get(TableProperties.DEFAULT_NAME_MAPPING) == null) { + // Forces Name based resolution instead of position based resolution + NameMapping mapping = MappingUtil.create(table.schema()); + String mappingJson = NameMappingParser.toJson(mapping); + table.updateProperties().set(TableProperties.DEFAULT_NAME_MAPPING, mappingJson).commit(); + } + } + + private void importFileTable( + Table table, + Path tableLocation, + String format, + Map partitionFilter, + boolean checkDuplicateFiles, + PartitionSpec spec, + int parallelism) { + // List Partitions via Spark InMemory file search interface + List partitions = + Spark3Util.getPartitions(spark(), tableLocation, format, partitionFilter, spec); + + if (table.spec().isUnpartitioned()) { + Preconditions.checkArgument( + partitions.isEmpty(), "Cannot add partitioned files to an unpartitioned table"); + 
Preconditions.checkArgument( + partitionFilter.isEmpty(), + "Cannot use a partition filter when importing" + "to an unpartitioned table"); + + // Build a Global Partition for the source + SparkPartition partition = + new SparkPartition(Collections.emptyMap(), tableLocation.toString(), format); + importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles, parallelism); + } else { + Preconditions.checkArgument( + !partitions.isEmpty(), "Cannot find any matching partitions in table %s", table.name()); + importPartitions(table, partitions, checkDuplicateFiles, parallelism); + } + } + + private void importCatalogTable( + Table table, + Identifier sourceIdent, + Map partitionFilter, + boolean checkDuplicateFiles, + int parallelism) { + String stagingLocation = getMetadataLocation(table); + TableIdentifier sourceTableIdentifier = Spark3Util.toV1TableIdentifier(sourceIdent); + SparkTableUtil.importSparkTable( + spark(), + sourceTableIdentifier, + table, + stagingLocation, + partitionFilter, + checkDuplicateFiles, + parallelism); + } + + private void importPartitions( + Table table, + List partitions, + boolean checkDuplicateFiles, + int parallelism) { + String stagingLocation = getMetadataLocation(table); + SparkTableUtil.importSparkPartitions( + spark(), + partitions, + table, + table.spec(), + stagingLocation, + checkDuplicateFiles, + parallelism); + } + + private String getMetadataLocation(Table table) { + String defaultValue = LocationUtil.stripTrailingSlash(table.location()) + "/metadata"; + return LocationUtil.stripTrailingSlash( + table.properties().getOrDefault(TableProperties.WRITE_METADATA_LOCATION, defaultValue)); + } + + @Override + public String description() { + return "AddFiles"; + } + + private void validatePartitionSpec(Table table, Map partitionFilter) { + List partitionFields = table.spec().fields(); + Set partitionNames = + table.spec().fields().stream().map(PartitionField::name).collect(Collectors.toSet()); + + boolean tablePartitioned = !partitionFields.isEmpty(); + boolean partitionSpecPassed = !partitionFilter.isEmpty(); + + // Check for any non-identity partition columns + List nonIdentityFields = + partitionFields.stream() + .filter(x -> !x.transform().isIdentity()) + .collect(Collectors.toList()); + Preconditions.checkArgument( + nonIdentityFields.isEmpty(), + "Cannot add data files to target table %s because that table is partitioned and contains non-identity" + + "partition transforms which will not be compatible. Found non-identity fields %s", + table.name(), + nonIdentityFields); + + if (tablePartitioned && partitionSpecPassed) { + // Check to see there are sufficient partition columns to satisfy the filter + Preconditions.checkArgument( + partitionFields.size() >= partitionFilter.size(), + "Cannot add data files to target table %s because that table is partitioned, " + + "but the number of columns in the provided partition filter (%s) " + + "is greater than the number of partitioned columns in table (%s)", + table.name(), + partitionFilter.size(), + partitionFields.size()); + + // Check for any filters of non existent columns + List unMatchedFilters = + partitionFilter.keySet().stream() + .filter(filterName -> !partitionNames.contains(filterName)) + .collect(Collectors.toList()); + Preconditions.checkArgument( + unMatchedFilters.isEmpty(), + "Cannot add files to target table %s. %s is partitioned but the specified partition filter " + + "refers to columns that are not partitioned: '%s' . 
Valid partition columns %s", + table.name(), + table.name(), + unMatchedFilters, + String.join(",", partitionNames)); + } else { + Preconditions.checkArgument( + !partitionSpecPassed, + "Cannot use partition filter with an unpartitioned table %s", + table.name()); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/AncestorsOfProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/AncestorsOfProcedure.java new file mode 100644 index 000000000000..c3a6ca138358 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/AncestorsOfProcedure.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.List; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class AncestorsOfProcedure extends BaseProcedure { + + private static final ProcedureParameter TABLE_PARAM = + ProcedureParameter.required("table", DataTypes.StringType); + private static final ProcedureParameter SNAPSHOT_ID_PARAM = + ProcedureParameter.optional("snapshot_id", DataTypes.LongType); + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] {TABLE_PARAM, SNAPSHOT_ID_PARAM}; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("snapshot_id", DataTypes.LongType, true, Metadata.empty()), + new StructField("timestamp", DataTypes.LongType, true, Metadata.empty()) + }); + + private AncestorsOfProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static SparkProcedures.ProcedureBuilder builder() { + return new Builder() { + @Override + protected AncestorsOfProcedure doBuild() { + return new AncestorsOfProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + ProcedureInput input = new ProcedureInput(spark(), tableCatalog(), PARAMETERS, args); + + Identifier tableIdent = input.ident(TABLE_PARAM); + Long toSnapshotId = 
input.asLong(SNAPSHOT_ID_PARAM, null); + + SparkTable sparkTable = loadSparkTable(tableIdent); + Table icebergTable = sparkTable.table(); + + if (toSnapshotId == null) { + toSnapshotId = + icebergTable.currentSnapshot() != null ? icebergTable.currentSnapshot().snapshotId() : -1; + } + + List snapshotIds = + Lists.newArrayList( + SnapshotUtil.ancestorIdsBetween(toSnapshotId, null, icebergTable::snapshot)); + + return toOutputRow(icebergTable, snapshotIds); + } + + @Override + public String description() { + return "AncestorsOf"; + } + + private InternalRow[] toOutputRow(Table table, List snapshotIds) { + if (snapshotIds.isEmpty()) { + return new InternalRow[0]; + } + + InternalRow[] internalRows = new InternalRow[snapshotIds.size()]; + for (int i = 0; i < snapshotIds.size(); i++) { + Long snapshotId = snapshotIds.get(i); + internalRows[i] = newInternalRow(snapshotId, table.snapshot(snapshotId).timestampMillis()); + } + + return internalRows; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/BaseProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/BaseProcedure.java new file mode 100644 index 000000000000..fb8bdc252df5 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/BaseProcedure.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
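The AncestorsOfProcedure above is normally invoked through Spark SQL. A minimal sketch, assuming Iceberg's SQL extensions are enabled and an Iceberg catalog is registered as my_catalog (catalog, table, and snapshot id are placeholders):

import org.apache.spark.sql.SparkSession;

public class AncestorsOfExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();
    // Walk the ancestry of the current snapshot of db.tbl.
    spark.sql("CALL my_catalog.system.ancestors_of('db.tbl')").show();
    // Or start from an explicit snapshot id using named arguments.
    spark.sql("CALL my_catalog.system.ancestors_of(snapshot_id => 1234567890, table => 'db.tbl')").show();
  }
}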
+ */ +package org.apache.iceberg.spark.procedures; + +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.function.Function; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.Procedure; +import org.apache.spark.sql.execution.CacheManager; +import org.apache.spark.sql.execution.datasources.SparkExpressionConverter; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import scala.Option; + +abstract class BaseProcedure implements Procedure { + protected static final DataType STRING_MAP = + DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType); + protected static final DataType STRING_ARRAY = DataTypes.createArrayType(DataTypes.StringType); + + private final SparkSession spark; + private final TableCatalog tableCatalog; + + private SparkActions actions; + private ExecutorService executorService = null; + + protected BaseProcedure(TableCatalog tableCatalog) { + this.spark = SparkSession.active(); + this.tableCatalog = tableCatalog; + } + + protected SparkSession spark() { + return this.spark; + } + + protected SparkActions actions() { + if (actions == null) { + this.actions = SparkActions.get(spark); + } + return actions; + } + + protected TableCatalog tableCatalog() { + return this.tableCatalog; + } + + protected T modifyIcebergTable(Identifier ident, Function func) { + try { + return execute(ident, true, func); + } finally { + closeService(); + } + } + + protected T withIcebergTable(Identifier ident, Function func) { + try { + return execute(ident, false, func); + } finally { + closeService(); + } + } + + private T execute( + Identifier ident, boolean refreshSparkCache, Function func) { + SparkTable sparkTable = loadSparkTable(ident); + org.apache.iceberg.Table icebergTable = sparkTable.table(); + + T result = func.apply(icebergTable); + + if (refreshSparkCache) { + refreshSparkCache(ident, sparkTable); + } + + return result; + } + + protected Identifier toIdentifier(String identifierAsString, String argName) { + CatalogAndIdentifier catalogAndIdentifier = + toCatalogAndIdentifier(identifierAsString, argName, 
tableCatalog); + + Preconditions.checkArgument( + catalogAndIdentifier.catalog().equals(tableCatalog), + "Cannot run procedure in catalog '%s': '%s' is a table in catalog '%s'", + tableCatalog.name(), + identifierAsString, + catalogAndIdentifier.catalog().name()); + + return catalogAndIdentifier.identifier(); + } + + protected CatalogAndIdentifier toCatalogAndIdentifier( + String identifierAsString, String argName, CatalogPlugin catalog) { + Preconditions.checkArgument( + identifierAsString != null && !identifierAsString.isEmpty(), + "Cannot handle an empty identifier for argument %s", + argName); + + return Spark3Util.catalogAndIdentifier( + "identifier for arg " + argName, spark, identifierAsString, catalog); + } + + protected SparkTable loadSparkTable(Identifier ident) { + try { + Table table = tableCatalog.loadTable(ident); + ValidationException.check( + table instanceof SparkTable, "%s is not %s", ident, SparkTable.class.getName()); + return (SparkTable) table; + } catch (NoSuchTableException e) { + String errMsg = + String.format("Couldn't load table '%s' in catalog '%s'", ident, tableCatalog.name()); + throw new RuntimeException(errMsg, e); + } + } + + protected Dataset loadRows(Identifier tableIdent, Map options) { + String tableName = Spark3Util.quotedFullIdentifier(tableCatalog().name(), tableIdent); + return spark().read().options(options).table(tableName); + } + + protected void refreshSparkCache(Identifier ident, Table table) { + CacheManager cacheManager = spark.sharedState().cacheManager(); + DataSourceV2Relation relation = + DataSourceV2Relation.create(table, Option.apply(tableCatalog), Option.apply(ident)); + cacheManager.recacheByPlan(spark, relation); + } + + protected Expression filterExpression(Identifier ident, String where) { + try { + String name = Spark3Util.quotedFullIdentifier(tableCatalog.name(), ident); + org.apache.spark.sql.catalyst.expressions.Expression expression = + SparkExpressionConverter.collectResolvedSparkExpression(spark, name, where); + return SparkExpressionConverter.convertToIcebergExpression(expression); + } catch (AnalysisException e) { + throw new IllegalArgumentException("Cannot parse predicates in where option: " + where, e); + } + } + + protected InternalRow newInternalRow(Object... values) { + return new GenericInternalRow(values); + } + + protected abstract static class Builder implements ProcedureBuilder { + private TableCatalog tableCatalog; + + @Override + public Builder withTableCatalog(TableCatalog newTableCatalog) { + this.tableCatalog = newTableCatalog; + return this; + } + + @Override + public T build() { + return doBuild(); + } + + protected abstract T doBuild(); + + TableCatalog tableCatalog() { + return tableCatalog; + } + } + + /** + * Closes this procedure's executor service if a new one was created with {@link + * BaseProcedure#executorService(int, String)}. Does not block for any remaining tasks. + */ + protected void closeService() { + if (executorService != null) { + executorService.shutdown(); + } + } + + /** + * Starts a new executor service which can be used by this procedure in its work. The pool will be + * automatically shut down if {@link #withIcebergTable(Identifier, Function)} or {@link + * #modifyIcebergTable(Identifier, Function)} are called. If these methods are not used then the + * service can be shut down with {@link #closeService()} or left to be closed when this class is + * finalized. 
+ * + * @param threadPoolSize number of threads in the service + * @param nameFormat name prefix for threads created in this service + * @return the new executor service owned by this procedure + */ + protected ExecutorService executorService(int threadPoolSize, String nameFormat) { + Preconditions.checkArgument( + executorService == null, "Cannot create a new executor service, one already exists."); + Preconditions.checkArgument( + nameFormat != null, "Cannot create a service with null nameFormat arg"); + this.executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) + Executors.newFixedThreadPool( + threadPoolSize, + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(nameFormat + "-%d") + .build())); + + return executorService; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/CherrypickSnapshotProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/CherrypickSnapshotProcedure.java new file mode 100644 index 000000000000..efe9aeb9e8e8 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/CherrypickSnapshotProcedure.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that applies changes in a given snapshot and creates a new snapshot which will be set + * as the current snapshot in a table. + * + *

Note: this procedure invalidates all cached Spark plans that reference the affected + * table. + * + * @see org.apache.iceberg.ManageSnapshots#cherrypick(long) + */ +class CherrypickSnapshotProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("snapshot_id", DataTypes.LongType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("source_snapshot_id", DataTypes.LongType, false, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected CherrypickSnapshotProcedure doBuild() { + return new CherrypickSnapshotProcedure(tableCatalog()); + } + }; + } + + private CherrypickSnapshotProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + long snapshotId = args.getLong(1); + + return modifyIcebergTable( + tableIdent, + table -> { + table.manageSnapshots().cherrypick(snapshotId).commit(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + InternalRow outputRow = newInternalRow(snapshotId, currentSnapshot.snapshotId()); + return new InternalRow[] {outputRow}; + }); + } + + @Override + public String description() { + return "CherrypickSnapshotProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java new file mode 100644 index 000000000000..ae77b69133f3 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
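A minimal sketch of calling the cherrypick from SQL, assuming the same my_catalog setup as above (table name and snapshot id are placeholders):

import org.apache.spark.sql.SparkSession;

public class CherrypickSnapshotExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();
    // Apply the changes of a staged snapshot (for example one written under WAP) to the
    // current table state; the output reports the source and new current snapshot ids.
    spark.sql("CALL my_catalog.system.cherrypick_snapshot('db.tbl', 123456789)").show();
  }
}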
+ */ +package org.apache.iceberg.spark.procedures; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.ChangelogIterator; +import org.apache.iceberg.spark.source.SparkChangelogTable; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.OrderUtils; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A procedure that creates a view for changed rows. + * + *

The procedure always removes the carry-over rows. Please query {@link SparkChangelogTable} + * instead when carry-over rows are required. + * + *

The procedure doesn't compute the pre/post update images by default. To compute + * them, set "compute_updates" to true in the options. + * + *

Carry-over rows are the result of a removal and insertion of the same row within an operation + * because of the copy-on-write mechanism. For example, given a file which contains row1 (id=1, + * data='a') and row2 (id=2, data='b'), a copy-on-write delete of row2 would require erasing this + * file and preserving row1 in a new file. The changelog table would report this as (id=1, data='a', + * op='DELETE') and (id=1, data='a', op='INSERT'), even though it is not an actual change to the + * table. The procedure finds the carry-over rows and removes them from the result. + * + *

Pre/post update images are converted from a pair of a delete row and an insert row. Identifier + * columns are used to determine whether an insert and a delete record refer to the same row. If + * the two records share the same values for the identifier columns, they are considered to be the + * before and after states of the same row. You can either set identifier fields in the table schema + * or pass them in as procedure parameters. Here is an example of pre/post update images with an + * identifier column (id). A pair of a delete row and an insert row with the same id: + * + *

    + *
  • (id=1, data='a', op='DELETE') + *
  • (id=1, data='b', op='INSERT') + *
+ * + *

will be marked as pre/post update images: + * + *

    + *
  • (id=1, data='a', op='UPDATE_BEFORE') + *
  • (id=1, data='b', op='UPDATE_AFTER') + *
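A minimal sketch of creating and querying such a changelog view, assuming the same my_catalog setup as above (table and view names are placeholders):

import org.apache.spark.sql.SparkSession;

public class CreateChangelogViewExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();
    // Build a changelog view with pre/post update images, pairing deletes and inserts on "id".
    spark.sql(
        "CALL my_catalog.system.create_changelog_view("
            + "table => 'db.tbl', "
            + "changelog_view => 'tbl_changes', "
            + "identifier_columns => array('id'))");
    // The resulting temp view exposes change metadata columns such as _change_type.
    spark.sql("SELECT * FROM tbl_changes").show();
  }
}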
+ */ +public class CreateChangelogViewProcedure extends BaseProcedure { + + private static final ProcedureParameter TABLE_PARAM = + ProcedureParameter.required("table", DataTypes.StringType); + private static final ProcedureParameter CHANGELOG_VIEW_PARAM = + ProcedureParameter.optional("changelog_view", DataTypes.StringType); + private static final ProcedureParameter OPTIONS_PARAM = + ProcedureParameter.optional("options", STRING_MAP); + private static final ProcedureParameter COMPUTE_UPDATES_PARAM = + ProcedureParameter.optional("compute_updates", DataTypes.BooleanType); + private static final ProcedureParameter IDENTIFIER_COLUMNS_PARAM = + ProcedureParameter.optional("identifier_columns", STRING_ARRAY); + private static final ProcedureParameter NET_CHANGES = + ProcedureParameter.optional("net_changes", DataTypes.BooleanType); + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + TABLE_PARAM, + CHANGELOG_VIEW_PARAM, + OPTIONS_PARAM, + COMPUTE_UPDATES_PARAM, + IDENTIFIER_COLUMNS_PARAM, + NET_CHANGES, + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("changelog_view", DataTypes.StringType, false, Metadata.empty()) + }); + + public static SparkProcedures.ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected CreateChangelogViewProcedure doBuild() { + return new CreateChangelogViewProcedure(tableCatalog()); + } + }; + } + + private CreateChangelogViewProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + ProcedureInput input = new ProcedureInput(spark(), tableCatalog(), PARAMETERS, args); + + Identifier tableIdent = input.ident(TABLE_PARAM); + + // load insert and deletes from the changelog table + Identifier changelogTableIdent = changelogTableIdent(tableIdent); + Dataset df = loadRows(changelogTableIdent, options(input)); + + boolean netChanges = input.asBoolean(NET_CHANGES, false); + String[] identifierColumns = identifierColumns(input, tableIdent); + Set unorderableColumnNames = + Arrays.stream(df.schema().fields()) + .filter(field -> !OrderUtils.isOrderable(field.dataType())) + .map(StructField::name) + .collect(Collectors.toSet()); + + Preconditions.checkArgument( + identifierColumns.length > 0 || unorderableColumnNames.isEmpty(), + "Identifier field is required as table contains unorderable columns: %s", + unorderableColumnNames); + + if (shouldComputeUpdateImages(input)) { + Preconditions.checkArgument(!netChanges, "Not support net changes with update images"); + df = computeUpdateImages(identifierColumns, df); + } else { + df = removeCarryoverRows(df, netChanges); + } + + String viewName = viewName(input, tableIdent.name()); + + df.createOrReplaceTempView(viewName); + + return toOutputRows(viewName); + } + + private Dataset computeUpdateImages(String[] identifierColumns, Dataset df) { + Preconditions.checkArgument( + identifierColumns.length > 0, + "Cannot compute the update images because identifier columns are not set"); + + Column[] repartitionSpec = new Column[identifierColumns.length + 1]; + for (int i = 0; i < identifierColumns.length; i++) { + repartitionSpec[i] = df.col(identifierColumns[i]); + } + + repartitionSpec[repartitionSpec.length - 1] = df.col(MetadataColumns.CHANGE_ORDINAL.name()); + + return 
applyChangelogIterator(df, repartitionSpec); + } + + private boolean shouldComputeUpdateImages(ProcedureInput input) { + // If the identifier columns are set, we compute pre/post update images by default. + boolean defaultValue = input.isProvided(IDENTIFIER_COLUMNS_PARAM); + return input.asBoolean(COMPUTE_UPDATES_PARAM, defaultValue); + } + + private Dataset removeCarryoverRows(Dataset df, boolean netChanges) { + Predicate columnsToKeep; + if (netChanges) { + Set metadataColumn = + Sets.newHashSet( + MetadataColumns.CHANGE_TYPE.name(), + MetadataColumns.CHANGE_ORDINAL.name(), + MetadataColumns.COMMIT_SNAPSHOT_ID.name()); + + columnsToKeep = column -> !metadataColumn.contains(column); + } else { + columnsToKeep = column -> !column.equals(MetadataColumns.CHANGE_TYPE.name()); + } + + Column[] repartitionSpec = + Arrays.stream(df.columns()).filter(columnsToKeep).map(df::col).toArray(Column[]::new); + return applyCarryoverRemoveIterator(df, repartitionSpec, netChanges); + } + + private String[] identifierColumns(ProcedureInput input, Identifier tableIdent) { + if (input.isProvided(IDENTIFIER_COLUMNS_PARAM)) { + return input.asStringArray(IDENTIFIER_COLUMNS_PARAM); + } else { + Table table = loadSparkTable(tableIdent).table(); + return table.schema().identifierFieldNames().toArray(new String[0]); + } + } + + private Identifier changelogTableIdent(Identifier tableIdent) { + List namespace = Lists.newArrayList(); + namespace.addAll(Arrays.asList(tableIdent.namespace())); + namespace.add(tableIdent.name()); + return Identifier.of(namespace.toArray(new String[0]), SparkChangelogTable.TABLE_NAME); + } + + private Map options(ProcedureInput input) { + return input.asStringMap(OPTIONS_PARAM, ImmutableMap.of()); + } + + private String viewName(ProcedureInput input, String tableName) { + String defaultValue = String.format("`%s_changes`", tableName); + return input.asString(CHANGELOG_VIEW_PARAM, defaultValue); + } + + private Dataset applyChangelogIterator(Dataset df, Column[] repartitionSpec) { + Column[] sortSpec = sortSpec(df, repartitionSpec, false); + StructType schema = df.schema(); + String[] identifierFields = + Arrays.stream(repartitionSpec).map(Column::toString).toArray(String[]::new); + + return df.repartition(repartitionSpec) + .sortWithinPartitions(sortSpec) + .mapPartitions( + (MapPartitionsFunction) + rowIterator -> + ChangelogIterator.computeUpdates(rowIterator, schema, identifierFields), + Encoders.row(schema)); + } + + private Dataset applyCarryoverRemoveIterator( + Dataset df, Column[] repartitionSpec, boolean netChanges) { + Column[] sortSpec = sortSpec(df, repartitionSpec, netChanges); + StructType schema = df.schema(); + + return df.repartition(repartitionSpec) + .sortWithinPartitions(sortSpec) + .mapPartitions( + (MapPartitionsFunction) + rowIterator -> + netChanges + ? ChangelogIterator.removeNetCarryovers(rowIterator, schema) + : ChangelogIterator.removeCarryovers(rowIterator, schema), + Encoders.row(schema)); + } + + private static Column[] sortSpec(Dataset df, Column[] repartitionSpec, boolean netChanges) { + Column changeType = df.col(MetadataColumns.CHANGE_TYPE.name()); + Column changeOrdinal = df.col(MetadataColumns.CHANGE_ORDINAL.name()); + Column[] extraColumns = + netChanges ? 
new Column[] {changeOrdinal, changeType} : new Column[] {changeType}; + + Column[] sortSpec = new Column[repartitionSpec.length + extraColumns.length]; + + System.arraycopy(repartitionSpec, 0, sortSpec, 0, repartitionSpec.length); + System.arraycopy(extraColumns, 0, sortSpec, repartitionSpec.length, extraColumns.length); + + return sortSpec; + } + + private InternalRow[] toOutputRows(String viewName) { + InternalRow row = newInternalRow(UTF8String.fromString(viewName)); + return new InternalRow[] {row}; + } + + @Override + public String description() { + return "CreateChangelogViewProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/ExpireSnapshotsProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/ExpireSnapshotsProcedure.java new file mode 100644 index 000000000000..b84d69ea9c1d --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/ExpireSnapshotsProcedure.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.ExpireSnapshots; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.actions.ExpireSnapshotsSparkAction; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A procedure that expires snapshots in a table. 
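A minimal sketch of the corresponding SQL invocation, assuming the same my_catalog setup as above (the timestamp and retention count are arbitrary):

import org.apache.spark.sql.SparkSession;

public class ExpireSnapshotsExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();
    // Expire snapshots older than the given timestamp while always retaining the last 10.
    spark.sql(
        "CALL my_catalog.system.expire_snapshots("
            + "table => 'db.tbl', "
            + "older_than => TIMESTAMP '2024-01-01 00:00:00', "
            + "retain_last => 10)")
        .show();
  }
}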
+ * + * @see SparkActions#expireSnapshots(Table) + */ +public class ExpireSnapshotsProcedure extends BaseProcedure { + + private static final Logger LOG = LoggerFactory.getLogger(ExpireSnapshotsProcedure.class); + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("older_than", DataTypes.TimestampType), + ProcedureParameter.optional("retain_last", DataTypes.IntegerType), + ProcedureParameter.optional("max_concurrent_deletes", DataTypes.IntegerType), + ProcedureParameter.optional("stream_results", DataTypes.BooleanType), + ProcedureParameter.optional("snapshot_ids", DataTypes.createArrayType(DataTypes.LongType)) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("deleted_data_files_count", DataTypes.LongType, true, Metadata.empty()), + new StructField( + "deleted_position_delete_files_count", DataTypes.LongType, true, Metadata.empty()), + new StructField( + "deleted_equality_delete_files_count", DataTypes.LongType, true, Metadata.empty()), + new StructField( + "deleted_manifest_files_count", DataTypes.LongType, true, Metadata.empty()), + new StructField( + "deleted_manifest_lists_count", DataTypes.LongType, true, Metadata.empty()), + new StructField( + "deleted_statistics_files_count", DataTypes.LongType, true, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected ExpireSnapshotsProcedure doBuild() { + return new ExpireSnapshotsProcedure(tableCatalog()); + } + }; + } + + private ExpireSnapshotsProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + @SuppressWarnings("checkstyle:CyclomaticComplexity") + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + Long olderThanMillis = args.isNullAt(1) ? null : DateTimeUtil.microsToMillis(args.getLong(1)); + Integer retainLastNum = args.isNullAt(2) ? null : args.getInt(2); + Integer maxConcurrentDeletes = args.isNullAt(3) ? null : args.getInt(3); + Boolean streamResult = args.isNullAt(4) ? null : args.getBoolean(4); + long[] snapshotIds = args.isNullAt(5) ? null : args.getArray(5).toLongArray(); + + Preconditions.checkArgument( + maxConcurrentDeletes == null || maxConcurrentDeletes > 0, + "max_concurrent_deletes should have value > 0, value: %s", + maxConcurrentDeletes); + + return modifyIcebergTable( + tableIdent, + table -> { + ExpireSnapshots action = actions().expireSnapshots(table); + + if (olderThanMillis != null) { + action.expireOlderThan(olderThanMillis); + } + + if (retainLastNum != null) { + action.retainLast(retainLastNum); + } + + if (maxConcurrentDeletes != null) { + if (table.io() instanceof SupportsBulkOperations) { + LOG.warn( + "max_concurrent_deletes only works with FileIOs that do not support bulk deletes. This " + + "table is currently using {} which supports bulk deletes so the parameter will be ignored. 
" + + "See that IO's documentation to learn how to adjust parallelism for that particular " + + "IO's bulk delete.", + table.io().getClass().getName()); + } else { + + action.executeDeleteWith(executorService(maxConcurrentDeletes, "expire-snapshots")); + } + } + + if (snapshotIds != null) { + for (long snapshotId : snapshotIds) { + action.expireSnapshotId(snapshotId); + } + } + + if (streamResult != null) { + action.option( + ExpireSnapshotsSparkAction.STREAM_RESULTS, Boolean.toString(streamResult)); + } + + ExpireSnapshots.Result result = action.execute(); + + return toOutputRows(result); + }); + } + + private InternalRow[] toOutputRows(ExpireSnapshots.Result result) { + InternalRow row = + newInternalRow( + result.deletedDataFilesCount(), + result.deletedPositionDeleteFilesCount(), + result.deletedEqualityDeleteFilesCount(), + result.deletedManifestsCount(), + result.deletedManifestListsCount(), + result.deletedStatisticsFilesCount()); + return new InternalRow[] {row}; + } + + @Override + public String description() { + return "ExpireSnapshotProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/FastForwardBranchProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/FastForwardBranchProcedure.java new file mode 100644 index 000000000000..11ea5d44c9f8 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/FastForwardBranchProcedure.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.procedures; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class FastForwardBranchProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("branch", DataTypes.StringType), + ProcedureParameter.required("to", DataTypes.StringType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("branch_updated", DataTypes.StringType, false, Metadata.empty()), + new StructField("previous_ref", DataTypes.LongType, true, Metadata.empty()), + new StructField("updated_ref", DataTypes.LongType, false, Metadata.empty()) + }); + + public static SparkProcedures.ProcedureBuilder builder() { + return new Builder() { + @Override + protected FastForwardBranchProcedure doBuild() { + return new FastForwardBranchProcedure(tableCatalog()); + } + }; + } + + private FastForwardBranchProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + String from = args.getString(1); + String to = args.getString(2); + + return modifyIcebergTable( + tableIdent, + table -> { + Long snapshotBefore = + table.snapshot(from) != null ? table.snapshot(from).snapshotId() : null; + table.manageSnapshots().fastForwardBranch(from, to).commit(); + long snapshotAfter = table.snapshot(from).snapshotId(); + InternalRow outputRow = + newInternalRow(UTF8String.fromString(from), snapshotBefore, snapshotAfter); + return new InternalRow[] {outputRow}; + }); + } + + @Override + public String description() { + return "FastForwardBranchProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java new file mode 100644 index 000000000000..a0bd04dd997e --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
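A minimal sketch of invoking the branch fast-forward above from SQL, assuming the same my_catalog setup as above (table and branch names are placeholders):

import org.apache.spark.sql.SparkSession;

public class FastForwardBranchExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();
    // Fast-forward "main" to the head of "audit-branch"; the output reports the branch name,
    // its previous snapshot id (null if the branch did not yet exist), and the updated snapshot id.
    spark.sql("CALL my_catalog.system.fast_forward('db.tbl', 'main', 'audit-branch')").show();
  }
}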
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.Map; +import org.apache.iceberg.actions.MigrateTable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.actions.MigrateTableSparkAction; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.runtime.BoxedUnit; + +class MigrateTableProcedure extends BaseProcedure { + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("properties", STRING_MAP), + ProcedureParameter.optional("drop_backup", DataTypes.BooleanType), + ProcedureParameter.optional("backup_table_name", DataTypes.StringType), + ProcedureParameter.optional("parallelism", DataTypes.IntegerType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("migrated_files_count", DataTypes.LongType, false, Metadata.empty()) + }); + + private MigrateTableProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected MigrateTableProcedure doBuild() { + return new MigrateTableProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + String tableName = args.getString(0); + Preconditions.checkArgument( + tableName != null && !tableName.isEmpty(), + "Cannot handle an empty identifier for argument table"); + + Map properties = Maps.newHashMap(); + if (!args.isNullAt(1)) { + args.getMap(1) + .foreach( + DataTypes.StringType, + DataTypes.StringType, + (k, v) -> { + properties.put(k.toString(), v.toString()); + return BoxedUnit.UNIT; + }); + } + + boolean dropBackup = args.isNullAt(2) ? false : args.getBoolean(2); + String backupTableName = args.isNullAt(3) ? 
null : args.getString(3); + + MigrateTableSparkAction migrateTableSparkAction = + SparkActions.get().migrateTable(tableName).tableProperties(properties); + + if (dropBackup) { + migrateTableSparkAction = migrateTableSparkAction.dropBackup(); + } + + if (backupTableName != null) { + migrateTableSparkAction = migrateTableSparkAction.backupTableName(backupTableName); + } + + if (!args.isNullAt(4)) { + int parallelism = args.getInt(4); + Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0"); + migrateTableSparkAction = + migrateTableSparkAction.executeWith(executorService(parallelism, "table-migration")); + } + + MigrateTable.Result result = migrateTableSparkAction.execute(); + return new InternalRow[] {newInternalRow(result.migratedDataFilesCount())}; + } + + @Override + public String description() { + return "MigrateTableProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java new file mode 100644 index 000000000000..0be4b38de79c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.lang.reflect.Array; +import java.util.Map; +import java.util.function.BiFunction; +import org.apache.commons.lang3.StringUtils; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; + +/** A class that abstracts common logic for working with input to a procedure. 
*/ +class ProcedureInput { + + private static final DataType STRING_ARRAY = DataTypes.createArrayType(DataTypes.StringType); + private static final DataType STRING_MAP = + DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType); + + private final SparkSession spark; + private final TableCatalog catalog; + private final Map paramOrdinals; + private final InternalRow args; + + ProcedureInput( + SparkSession spark, TableCatalog catalog, ProcedureParameter[] params, InternalRow args) { + this.spark = spark; + this.catalog = catalog; + this.paramOrdinals = computeParamOrdinals(params); + this.args = args; + } + + public boolean isProvided(ProcedureParameter param) { + int ordinal = ordinal(param); + return !args.isNullAt(ordinal); + } + + public Boolean asBoolean(ProcedureParameter param, Boolean defaultValue) { + validateParamType(param, DataTypes.BooleanType); + int ordinal = ordinal(param); + return args.isNullAt(ordinal) ? defaultValue : (Boolean) args.getBoolean(ordinal); + } + + public Integer asInt(ProcedureParameter param) { + Integer value = asInt(param, null); + Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name()); + return value; + } + + public Integer asInt(ProcedureParameter param, Integer defaultValue) { + validateParamType(param, DataTypes.IntegerType); + int ordinal = ordinal(param); + return args.isNullAt(ordinal) ? defaultValue : (Integer) args.getInt(ordinal); + } + + public long asLong(ProcedureParameter param) { + Long value = asLong(param, null); + Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name()); + return value; + } + + public Long asLong(ProcedureParameter param, Long defaultValue) { + validateParamType(param, DataTypes.LongType); + int ordinal = ordinal(param); + return args.isNullAt(ordinal) ? defaultValue : (Long) args.getLong(ordinal); + } + + public String asString(ProcedureParameter param) { + String value = asString(param, null); + Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name()); + return value; + } + + public String asString(ProcedureParameter param, String defaultValue) { + validateParamType(param, DataTypes.StringType); + int ordinal = ordinal(param); + return args.isNullAt(ordinal) ? 
defaultValue : args.getString(ordinal); + } + + public String[] asStringArray(ProcedureParameter param) { + String[] value = asStringArray(param, null); + Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name()); + return value; + } + + public String[] asStringArray(ProcedureParameter param, String[] defaultValue) { + validateParamType(param, STRING_ARRAY); + return array( + param, + (array, ordinal) -> array.getUTF8String(ordinal).toString(), + String.class, + defaultValue); + } + + @SuppressWarnings("unchecked") + private T[] array( + ProcedureParameter param, + BiFunction convertElement, + Class elementClass, + T[] defaultValue) { + + int ordinal = ordinal(param); + + if (args.isNullAt(ordinal)) { + return defaultValue; + } + + ArrayData arrayData = args.getArray(ordinal); + + T[] convertedArray = (T[]) Array.newInstance(elementClass, arrayData.numElements()); + + for (int index = 0; index < arrayData.numElements(); index++) { + convertedArray[index] = convertElement.apply(arrayData, index); + } + + return convertedArray; + } + + public Map asStringMap( + ProcedureParameter param, Map defaultValue) { + validateParamType(param, STRING_MAP); + return map( + param, + (keys, ordinal) -> keys.getUTF8String(ordinal).toString(), + (values, ordinal) -> values.getUTF8String(ordinal).toString(), + defaultValue); + } + + private Map map( + ProcedureParameter param, + BiFunction convertKey, + BiFunction convertValue, + Map defaultValue) { + + int ordinal = ordinal(param); + + if (args.isNullAt(ordinal)) { + return defaultValue; + } + + MapData mapData = args.getMap(ordinal); + + Map convertedMap = Maps.newHashMap(); + + for (int index = 0; index < mapData.numElements(); index++) { + K convertedKey = convertKey.apply(mapData.keyArray(), index); + V convertedValue = convertValue.apply(mapData.valueArray(), index); + convertedMap.put(convertedKey, convertedValue); + } + + return convertedMap; + } + + public Identifier ident(ProcedureParameter param) { + CatalogAndIdentifier catalogAndIdent = catalogAndIdent(param, catalog); + + Preconditions.checkArgument( + catalogAndIdent.catalog().equals(catalog), + "Cannot run procedure in catalog '%s': '%s' is a table in catalog '%s'", + catalog.name(), + catalogAndIdent.identifier(), + catalogAndIdent.catalog().name()); + + return catalogAndIdent.identifier(); + } + + public Identifier ident(ProcedureParameter param, CatalogPlugin defaultCatalog) { + CatalogAndIdentifier catalogAndIdent = catalogAndIdent(param, defaultCatalog); + return catalogAndIdent.identifier(); + } + + private CatalogAndIdentifier catalogAndIdent( + ProcedureParameter param, CatalogPlugin defaultCatalog) { + + String identAsString = asString(param); + + Preconditions.checkArgument( + StringUtils.isNotBlank(identAsString), + "Cannot handle an empty identifier for parameter '%s'", + param.name()); + + String desc = String.format("identifier for parameter '%s'", param.name()); + return Spark3Util.catalogAndIdentifier(desc, spark, identAsString, defaultCatalog); + } + + private int ordinal(ProcedureParameter param) { + return paramOrdinals.get(param.name()); + } + + private Map computeParamOrdinals(ProcedureParameter[] params) { + Map ordinals = Maps.newHashMap(); + + for (int index = 0; index < params.length; index++) { + String paramName = params[index].name(); + + Preconditions.checkArgument( + !ordinals.containsKey(paramName), + "Detected multiple parameters named as '%s'", + paramName); + + ordinals.put(paramName, index); + } + + return ordinals; + } + + 
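+  // Note: the accessors above resolve a parameter's position through this name-to-ordinal map, so
+  // procedure implementations can reference parameters via their declared ProcedureParameter
+  // objects instead of hard-coded ordinals; duplicate parameter names are rejected when the map is
+  // built.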
private void validateParamType(ProcedureParameter param, DataType expectedDataType) { + Preconditions.checkArgument( + expectedDataType.sameType(param.dataType()), + "Parameter '%s' must be of type %s", + param.name(), + expectedDataType.catalogString()); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/PublishChangesProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/PublishChangesProcedure.java new file mode 100644 index 000000000000..eb6c762ed51e --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/PublishChangesProcedure.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.Optional; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.util.WapUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that applies changes in a snapshot created within a Write-Audit-Publish workflow with + * a wap_id and creates a new snapshot which will be set as the current snapshot in a table. + * + *
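+ * <p>Usage sketch (illustrative; catalog, table, and WAP ID below are placeholders):
+ *
+ * <pre>{@code
+ *   CALL my_catalog.system.publish_changes(table => 'db.sample', wap_id => 'wap-1234')
+ * }</pre>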
<p>
Note: this procedure invalidates all cached Spark plans that reference the affected + * table. + * + * @see org.apache.iceberg.ManageSnapshots#cherrypick(long) + */ +class PublishChangesProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("wap_id", DataTypes.StringType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("source_snapshot_id", DataTypes.LongType, false, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new Builder() { + @Override + protected PublishChangesProcedure doBuild() { + return new PublishChangesProcedure(tableCatalog()); + } + }; + } + + private PublishChangesProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + String wapId = args.getString(1); + + return modifyIcebergTable( + tableIdent, + table -> { + Optional wapSnapshot = + Optional.ofNullable( + Iterables.find( + table.snapshots(), + snapshot -> wapId.equals(WapUtil.stagedWapId(snapshot)), + null)); + if (!wapSnapshot.isPresent()) { + throw new ValidationException(String.format("Cannot apply unknown WAP ID '%s'", wapId)); + } + + long wapSnapshotId = wapSnapshot.get().snapshotId(); + table.manageSnapshots().cherrypick(wapSnapshotId).commit(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + InternalRow outputRow = newInternalRow(wapSnapshotId, currentSnapshot.snapshotId()); + return new InternalRow[] {outputRow}; + }); + } + + @Override + public String description() { + return "ApplyWapChangesProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RegisterTableProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RegisterTableProcedure.java new file mode 100644 index 000000000000..857949e052c8 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RegisterTableProcedure.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.spark.source.HasIcebergCatalog; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +class RegisterTableProcedure extends BaseProcedure { + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("metadata_file", DataTypes.StringType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("current_snapshot_id", DataTypes.LongType, true, Metadata.empty()), + new StructField("total_records_count", DataTypes.LongType, true, Metadata.empty()), + new StructField("total_data_files_count", DataTypes.LongType, true, Metadata.empty()) + }); + + private RegisterTableProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected RegisterTableProcedure doBuild() { + return new RegisterTableProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + TableIdentifier tableName = + Spark3Util.identifierToTableIdentifier(toIdentifier(args.getString(0), "table")); + String metadataFile = args.getString(1); + Preconditions.checkArgument( + tableCatalog() instanceof HasIcebergCatalog, + "Cannot use Register Table in a non-Iceberg catalog"); + Preconditions.checkArgument( + metadataFile != null && !metadataFile.isEmpty(), + "Cannot handle an empty argument metadata_file"); + + Catalog icebergCatalog = ((HasIcebergCatalog) tableCatalog()).icebergCatalog(); + Table table = icebergCatalog.registerTable(tableName, metadataFile); + Long currentSnapshotId = null; + Long totalDataFiles = null; + Long totalRecords = null; + + Snapshot currentSnapshot = table.currentSnapshot(); + if (currentSnapshot != null) { + currentSnapshotId = currentSnapshot.snapshotId(); + totalDataFiles = + Long.parseLong(currentSnapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP)); + totalRecords = + Long.parseLong(currentSnapshot.summary().get(SnapshotSummary.TOTAL_RECORDS_PROP)); + } + + return new InternalRow[] {newInternalRow(currentSnapshotId, totalRecords, totalDataFiles)}; + } + + @Override + public String description() { + return "RegisterTableProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RemoveOrphanFilesProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RemoveOrphanFilesProcedure.java new file mode 100644 index 000000000000..6609efa95eb1 --- /dev/null +++ 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RemoveOrphanFilesProcedure.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.actions.DeleteOrphanFiles.PrefixMismatchMode; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.actions.DeleteOrphanFilesSparkAction; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.runtime.BoxedUnit; + +/** + * A procedure that removes orphan files in a table. 
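+ * <p>Usage sketch (illustrative; catalog and table names are placeholders). A dry run only lists
+ * the orphaned files that would otherwise be deleted:
+ *
+ * <pre>{@code
+ *   CALL my_catalog.system.remove_orphan_files(table => 'db.sample', dry_run => true)
+ * }</pre>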
+ * + * @see SparkActions#deleteOrphanFiles(Table) + */ +public class RemoveOrphanFilesProcedure extends BaseProcedure { + private static final Logger LOG = LoggerFactory.getLogger(RemoveOrphanFilesProcedure.class); + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("older_than", DataTypes.TimestampType), + ProcedureParameter.optional("location", DataTypes.StringType), + ProcedureParameter.optional("dry_run", DataTypes.BooleanType), + ProcedureParameter.optional("max_concurrent_deletes", DataTypes.IntegerType), + ProcedureParameter.optional("file_list_view", DataTypes.StringType), + ProcedureParameter.optional("equal_schemes", STRING_MAP), + ProcedureParameter.optional("equal_authorities", STRING_MAP), + ProcedureParameter.optional("prefix_mismatch_mode", DataTypes.StringType), + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("orphan_file_location", DataTypes.StringType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected RemoveOrphanFilesProcedure doBuild() { + return new RemoveOrphanFilesProcedure(tableCatalog()); + } + }; + } + + private RemoveOrphanFilesProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + @SuppressWarnings("checkstyle:CyclomaticComplexity") + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + Long olderThanMillis = args.isNullAt(1) ? null : DateTimeUtil.microsToMillis(args.getLong(1)); + String location = args.isNullAt(2) ? null : args.getString(2); + boolean dryRun = args.isNullAt(3) ? false : args.getBoolean(3); + Integer maxConcurrentDeletes = args.isNullAt(4) ? null : args.getInt(4); + String fileListView = args.isNullAt(5) ? null : args.getString(5); + + Preconditions.checkArgument( + maxConcurrentDeletes == null || maxConcurrentDeletes > 0, + "max_concurrent_deletes should have value > 0, value: %s", + maxConcurrentDeletes); + + Map equalSchemes = Maps.newHashMap(); + if (!args.isNullAt(6)) { + args.getMap(6) + .foreach( + DataTypes.StringType, + DataTypes.StringType, + (k, v) -> { + equalSchemes.put(k.toString(), v.toString()); + return BoxedUnit.UNIT; + }); + } + + Map equalAuthorities = Maps.newHashMap(); + if (!args.isNullAt(7)) { + args.getMap(7) + .foreach( + DataTypes.StringType, + DataTypes.StringType, + (k, v) -> { + equalAuthorities.put(k.toString(), v.toString()); + return BoxedUnit.UNIT; + }); + } + + PrefixMismatchMode prefixMismatchMode = + args.isNullAt(8) ? 
null : PrefixMismatchMode.fromString(args.getString(8)); + + return withIcebergTable( + tableIdent, + table -> { + DeleteOrphanFilesSparkAction action = actions().deleteOrphanFiles(table); + + if (olderThanMillis != null) { + boolean isTesting = Boolean.parseBoolean(spark().conf().get("spark.testing", "false")); + if (!isTesting) { + validateInterval(olderThanMillis); + } + action.olderThan(olderThanMillis); + } + + if (location != null) { + action.location(location); + } + + if (dryRun) { + action.deleteWith(file -> {}); + } + + if (maxConcurrentDeletes != null) { + if (table.io() instanceof SupportsBulkOperations) { + LOG.warn( + "max_concurrent_deletes only works with FileIOs that do not support bulk deletes. This " + + "table is currently using {} which supports bulk deletes so the parameter will be ignored. " + + "See that IO's documentation to learn how to adjust parallelism for that particular " + + "IO's bulk delete.", + table.io().getClass().getName()); + } else { + + action.executeDeleteWith(executorService(maxConcurrentDeletes, "remove-orphans")); + } + } + + if (fileListView != null) { + action.compareToFileList(spark().table(fileListView)); + } + + action.equalSchemes(equalSchemes); + action.equalAuthorities(equalAuthorities); + + if (prefixMismatchMode != null) { + action.prefixMismatchMode(prefixMismatchMode); + } + + DeleteOrphanFiles.Result result = action.execute(); + + return toOutputRows(result); + }); + } + + private InternalRow[] toOutputRows(DeleteOrphanFiles.Result result) { + Iterable orphanFileLocations = result.orphanFileLocations(); + + int orphanFileLocationsCount = Iterables.size(orphanFileLocations); + InternalRow[] rows = new InternalRow[orphanFileLocationsCount]; + + int index = 0; + for (String fileLocation : orphanFileLocations) { + rows[index] = newInternalRow(UTF8String.fromString(fileLocation)); + index++; + } + + return rows; + } + + private void validateInterval(long olderThanMillis) { + long intervalMillis = System.currentTimeMillis() - olderThanMillis; + if (intervalMillis < TimeUnit.DAYS.toMillis(1)) { + throw new IllegalArgumentException( + "Cannot remove orphan files with an interval less than 24 hours. Executing this " + + "procedure with a short interval may corrupt the table if other operations are happening " + + "at the same time. If you are absolutely confident that no concurrent operations will be " + + "affected by removing orphan files with such a short interval, you can use the Action API " + + "to remove orphan files with an arbitrary interval."); + } + } + + @Override + public String description() { + return "RemoveOrphanFilesProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java new file mode 100644 index 000000000000..bb6ba393a327 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.RewriteDataFiles; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.NamedReference; +import org.apache.iceberg.expressions.Zorder; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.ExtendedParser; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that rewrites datafiles in a table. + * + * @see org.apache.iceberg.spark.actions.SparkActions#rewriteDataFiles(Table) + */ +class RewriteDataFilesProcedure extends BaseProcedure { + + private static final ProcedureParameter TABLE_PARAM = + ProcedureParameter.required("table", DataTypes.StringType); + private static final ProcedureParameter STRATEGY_PARAM = + ProcedureParameter.optional("strategy", DataTypes.StringType); + private static final ProcedureParameter SORT_ORDER_PARAM = + ProcedureParameter.optional("sort_order", DataTypes.StringType); + private static final ProcedureParameter OPTIONS_PARAM = + ProcedureParameter.optional("options", STRING_MAP); + private static final ProcedureParameter WHERE_PARAM = + ProcedureParameter.optional("where", DataTypes.StringType); + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + TABLE_PARAM, STRATEGY_PARAM, SORT_ORDER_PARAM, OPTIONS_PARAM, WHERE_PARAM + }; + + // counts are not nullable since the action result is never null + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField( + "rewritten_data_files_count", DataTypes.IntegerType, false, Metadata.empty()), + new StructField( + "added_data_files_count", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("rewritten_bytes_count", DataTypes.LongType, false, Metadata.empty()), + new StructField( + "failed_data_files_count", DataTypes.IntegerType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new Builder() { + @Override + protected RewriteDataFilesProcedure doBuild() { + return new RewriteDataFilesProcedure(tableCatalog()); + } + }; + } + + private RewriteDataFilesProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public 
InternalRow[] call(InternalRow args) { + ProcedureInput input = new ProcedureInput(spark(), tableCatalog(), PARAMETERS, args); + Identifier tableIdent = input.ident(TABLE_PARAM); + String strategy = input.asString(STRATEGY_PARAM, null); + String sortOrderString = input.asString(SORT_ORDER_PARAM, null); + Map options = input.asStringMap(OPTIONS_PARAM, ImmutableMap.of()); + String where = input.asString(WHERE_PARAM, null); + + return modifyIcebergTable( + tableIdent, + table -> { + RewriteDataFiles action = actions().rewriteDataFiles(table).options(options); + + if (strategy != null || sortOrderString != null) { + action = checkAndApplyStrategy(action, strategy, sortOrderString, table.schema()); + } + + action = checkAndApplyFilter(action, where, tableIdent); + + RewriteDataFiles.Result result = action.execute(); + + return toOutputRows(result); + }); + } + + private RewriteDataFiles checkAndApplyFilter( + RewriteDataFiles action, String where, Identifier ident) { + if (where != null) { + Expression expression = filterExpression(ident, where); + return action.filter(expression); + } + return action; + } + + private RewriteDataFiles checkAndApplyStrategy( + RewriteDataFiles action, String strategy, String sortOrderString, Schema schema) { + List zOrderTerms = Lists.newArrayList(); + List sortOrderFields = Lists.newArrayList(); + if (sortOrderString != null) { + ExtendedParser.parseSortOrder(spark(), sortOrderString) + .forEach( + field -> { + if (field.term() instanceof Zorder) { + zOrderTerms.add((Zorder) field.term()); + } else { + sortOrderFields.add(field); + } + }); + + if (!zOrderTerms.isEmpty() && !sortOrderFields.isEmpty()) { + // TODO: we need to allow this in future when SparkAction has handling for this. + throw new IllegalArgumentException( + "Cannot mix identity sort columns and a Zorder sort expression: " + sortOrderString); + } + } + + // caller of this function ensures that between strategy and sortOrder, at least one of them is + // not null. + if (strategy == null || strategy.equalsIgnoreCase("sort")) { + if (!zOrderTerms.isEmpty()) { + String[] columnNames = + zOrderTerms.stream() + .flatMap(zOrder -> zOrder.refs().stream().map(NamedReference::name)) + .toArray(String[]::new); + return action.zOrder(columnNames); + } else if (!sortOrderFields.isEmpty()) { + return action.sort(buildSortOrder(sortOrderFields, schema)); + } else { + return action.sort(); + } + } + if (strategy.equalsIgnoreCase("binpack")) { + RewriteDataFiles rewriteDataFiles = action.binPack(); + if (sortOrderString != null) { + // calling below method to throw the error as user has set both binpack strategy and sort + // order + return rewriteDataFiles.sort(buildSortOrder(sortOrderFields, schema)); + } + return rewriteDataFiles; + } else { + throw new IllegalArgumentException( + "unsupported strategy: " + strategy + ". 
Only binpack or sort is supported"); + } + } + + private SortOrder buildSortOrder( + List rawOrderFields, Schema schema) { + SortOrder.Builder builder = SortOrder.builderFor(schema); + rawOrderFields.forEach( + rawField -> builder.sortBy(rawField.term(), rawField.direction(), rawField.nullOrder())); + return builder.build(); + } + + private InternalRow[] toOutputRows(RewriteDataFiles.Result result) { + int rewrittenDataFilesCount = result.rewrittenDataFilesCount(); + long rewrittenBytesCount = result.rewrittenBytesCount(); + int addedDataFilesCount = result.addedDataFilesCount(); + int failedDataFilesCount = result.failedDataFilesCount(); + + InternalRow row = + newInternalRow( + rewrittenDataFilesCount, + addedDataFilesCount, + rewrittenBytesCount, + failedDataFilesCount); + return new InternalRow[] {row}; + } + + @Override + public String description() { + return "RewriteDataFilesProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteManifestsProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteManifestsProcedure.java new file mode 100644 index 000000000000..e59077ae3da9 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteManifestsProcedure.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.RewriteManifests; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.actions.RewriteManifestsSparkAction; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that rewrites manifests in a table. + * + *
<p>
Note: this procedure invalidates all cached Spark plans that reference the affected + * table. + * + * @see SparkActions#rewriteManifests(Table) () + */ +class RewriteManifestsProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("use_caching", DataTypes.BooleanType), + ProcedureParameter.optional("spec_id", DataTypes.IntegerType) + }; + + // counts are not nullable since the action result is never null + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField( + "rewritten_manifests_count", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("added_manifests_count", DataTypes.IntegerType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected RewriteManifestsProcedure doBuild() { + return new RewriteManifestsProcedure(tableCatalog()); + } + }; + } + + private RewriteManifestsProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + Boolean useCaching = args.isNullAt(1) ? null : args.getBoolean(1); + Integer specId = args.isNullAt(2) ? null : args.getInt(2); + + return modifyIcebergTable( + tableIdent, + table -> { + RewriteManifestsSparkAction action = actions().rewriteManifests(table); + + if (useCaching != null) { + action.option(RewriteManifestsSparkAction.USE_CACHING, useCaching.toString()); + } + + if (specId != null) { + action.specId(specId); + } + + RewriteManifests.Result result = action.execute(); + + return toOutputRows(result); + }); + } + + private InternalRow[] toOutputRows(RewriteManifests.Result result) { + int rewrittenManifestsCount = Iterables.size(result.rewrittenManifests()); + int addedManifestsCount = Iterables.size(result.addedManifests()); + InternalRow row = newInternalRow(rewrittenManifestsCount, addedManifestsCount); + return new InternalRow[] {row}; + } + + @Override + public String description() { + return "RewriteManifestsProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RewritePositionDeleteFilesProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RewritePositionDeleteFilesProcedure.java new file mode 100644 index 000000000000..3d5e45ce8b89 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RewritePositionDeleteFilesProcedure.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.Map; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.RewritePositionDeleteFiles; +import org.apache.iceberg.actions.RewritePositionDeleteFiles.Result; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that rewrites position delete files in a table. + * + * @see org.apache.iceberg.spark.actions.SparkActions#rewritePositionDeletes(Table) + */ +public class RewritePositionDeleteFilesProcedure extends BaseProcedure { + + private static final ProcedureParameter TABLE_PARAM = + ProcedureParameter.required("table", DataTypes.StringType); + private static final ProcedureParameter OPTIONS_PARAM = + ProcedureParameter.optional("options", STRING_MAP); + private static final ProcedureParameter WHERE_PARAM = + ProcedureParameter.optional("where", DataTypes.StringType); + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] {TABLE_PARAM, OPTIONS_PARAM, WHERE_PARAM}; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField( + "rewritten_delete_files_count", DataTypes.IntegerType, false, Metadata.empty()), + new StructField( + "added_delete_files_count", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("rewritten_bytes_count", DataTypes.LongType, false, Metadata.empty()), + new StructField("added_bytes_count", DataTypes.LongType, false, Metadata.empty()) + }); + + public static SparkProcedures.ProcedureBuilder builder() { + return new Builder() { + @Override + protected RewritePositionDeleteFilesProcedure doBuild() { + return new RewritePositionDeleteFilesProcedure(tableCatalog()); + } + }; + } + + private RewritePositionDeleteFilesProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + ProcedureInput input = new ProcedureInput(spark(), tableCatalog(), PARAMETERS, args); + Identifier tableIdent = input.ident(TABLE_PARAM); + Map options = input.asStringMap(OPTIONS_PARAM, ImmutableMap.of()); + String where = input.asString(WHERE_PARAM, null); + + return modifyIcebergTable( + tableIdent, + table -> { + RewritePositionDeleteFiles action = + actions().rewritePositionDeletes(table).options(options); + + if (where != null) { + Expression whereExpression = filterExpression(tableIdent, where); + action = action.filter(whereExpression); + } + + Result result 
= action.execute(); + return new InternalRow[] {toOutputRow(result)}; + }); + } + + private InternalRow toOutputRow(Result result) { + return newInternalRow( + result.rewrittenDeleteFilesCount(), + result.addedDeleteFilesCount(), + result.rewrittenBytesCount(), + result.addedBytesCount()); + } + + @Override + public String description() { + return "RewritePositionDeleteFilesProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToSnapshotProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToSnapshotProcedure.java new file mode 100644 index 000000000000..49cc1a5ceae3 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToSnapshotProcedure.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that rolls back a table to a specific snapshot ID. + * + *
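+ * <p>Usage sketch (illustrative; catalog, table, and snapshot ID below are placeholders):
+ *
+ * <pre>{@code
+ *   CALL my_catalog.system.rollback_to_snapshot(table => 'db.sample', snapshot_id => 1234567890123456789)
+ * }</pre>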
<p>
Note: this procedure invalidates all cached Spark plans that reference the affected + * table. + * + * @see org.apache.iceberg.ManageSnapshots#rollbackTo(long) + */ +class RollbackToSnapshotProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("snapshot_id", DataTypes.LongType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("previous_snapshot_id", DataTypes.LongType, false, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + public RollbackToSnapshotProcedure doBuild() { + return new RollbackToSnapshotProcedure(tableCatalog()); + } + }; + } + + private RollbackToSnapshotProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + long snapshotId = args.getLong(1); + + return modifyIcebergTable( + tableIdent, + table -> { + Snapshot previousSnapshot = table.currentSnapshot(); + + table.manageSnapshots().rollbackTo(snapshotId).commit(); + + InternalRow outputRow = newInternalRow(previousSnapshot.snapshotId(), snapshotId); + return new InternalRow[] {outputRow}; + }); + } + + @Override + public String description() { + return "RollbackToSnapshotProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToTimestampProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToTimestampProcedure.java new file mode 100644 index 000000000000..059725f0c152 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToTimestampProcedure.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that rolls back a table to a given point in time. + * + *
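+ * <p>Usage sketch (illustrative; catalog, table, and timestamp below are placeholders):
+ *
+ * <pre>{@code
+ *   CALL my_catalog.system.rollback_to_timestamp(table => 'db.sample', timestamp => TIMESTAMP '2024-01-01 00:00:00')
+ * }</pre>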
<p>
Note: this procedure invalidates all cached Spark plans that reference the affected + * table. + * + * @see org.apache.iceberg.ManageSnapshots#rollbackToTime(long) + */ +class RollbackToTimestampProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("timestamp", DataTypes.TimestampType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("previous_snapshot_id", DataTypes.LongType, false, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected RollbackToTimestampProcedure doBuild() { + return new RollbackToTimestampProcedure(tableCatalog()); + } + }; + } + + private RollbackToTimestampProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + // timestamps in Spark have microsecond precision so this conversion is lossy + long timestampMillis = DateTimeUtil.microsToMillis(args.getLong(1)); + + return modifyIcebergTable( + tableIdent, + table -> { + Snapshot previousSnapshot = table.currentSnapshot(); + + table.manageSnapshots().rollbackToTime(timestampMillis).commit(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + InternalRow outputRow = + newInternalRow(previousSnapshot.snapshotId(), currentSnapshot.snapshotId()); + return new InternalRow[] {outputRow}; + }); + } + + @Override + public String description() { + return "RollbackToTimestampProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java new file mode 100644 index 000000000000..22719e43c057 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that sets the current snapshot in a table. + * + *
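+ * <p>Usage sketch (illustrative; catalog, table, and branch names are placeholders). Exactly one of
+ * snapshot_id or ref may be supplied, as enforced in call() below:
+ *
+ * <pre>{@code
+ *   CALL my_catalog.system.set_current_snapshot(table => 'db.sample', ref => 'audit-branch')
+ * }</pre>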
<p>
Note: this procedure invalidates all cached Spark plans that reference the affected + * table. + * + * @see org.apache.iceberg.ManageSnapshots#setCurrentSnapshot(long) + */ +class SetCurrentSnapshotProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("snapshot_id", DataTypes.LongType), + ProcedureParameter.optional("ref", DataTypes.StringType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("previous_snapshot_id", DataTypes.LongType, true, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected SetCurrentSnapshotProcedure doBuild() { + return new SetCurrentSnapshotProcedure(tableCatalog()); + } + }; + } + + private SetCurrentSnapshotProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + Long snapshotId = args.isNullAt(1) ? null : args.getLong(1); + String ref = args.isNullAt(2) ? null : args.getString(2); + Preconditions.checkArgument( + (snapshotId != null && ref == null) || (snapshotId == null && ref != null), + "Either snapshot_id or ref must be provided, not both"); + + return modifyIcebergTable( + tableIdent, + table -> { + Snapshot previousSnapshot = table.currentSnapshot(); + Long previousSnapshotId = previousSnapshot != null ? previousSnapshot.snapshotId() : null; + + long targetSnapshotId = snapshotId != null ? snapshotId : toSnapshotId(table, ref); + table.manageSnapshots().setCurrentSnapshot(targetSnapshotId).commit(); + + InternalRow outputRow = newInternalRow(previousSnapshotId, targetSnapshotId); + return new InternalRow[] {outputRow}; + }); + } + + @Override + public String description() { + return "SetCurrentSnapshotProcedure"; + } + + private long toSnapshotId(Table table, String refName) { + SnapshotRef ref = table.refs().get(refName); + ValidationException.check(ref != null, "Cannot find matching snapshot ID for ref " + refName); + return ref.snapshotId(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java new file mode 100644 index 000000000000..f709f64ebf62 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.Map; +import org.apache.iceberg.actions.SnapshotTable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.runtime.BoxedUnit; + +class SnapshotTableProcedure extends BaseProcedure { + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("source_table", DataTypes.StringType), + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("location", DataTypes.StringType), + ProcedureParameter.optional("properties", STRING_MAP), + ProcedureParameter.optional("parallelism", DataTypes.IntegerType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("imported_files_count", DataTypes.LongType, false, Metadata.empty()) + }); + + private SnapshotTableProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static SparkProcedures.ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected SnapshotTableProcedure doBuild() { + return new SnapshotTableProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + String source = args.getString(0); + Preconditions.checkArgument( + source != null && !source.isEmpty(), + "Cannot handle an empty identifier for argument source_table"); + String dest = args.getString(1); + Preconditions.checkArgument( + dest != null && !dest.isEmpty(), "Cannot handle an empty identifier for argument table"); + String snapshotLocation = args.isNullAt(2) ? 
null : args.getString(2); + + Map properties = Maps.newHashMap(); + if (!args.isNullAt(3)) { + args.getMap(3) + .foreach( + DataTypes.StringType, + DataTypes.StringType, + (k, v) -> { + properties.put(k.toString(), v.toString()); + return BoxedUnit.UNIT; + }); + } + + Preconditions.checkArgument( + !source.equals(dest), + "Cannot create a snapshot with the same name as the source of the snapshot."); + SnapshotTable action = SparkActions.get().snapshotTable(source).as(dest); + + if (snapshotLocation != null) { + action.tableLocation(snapshotLocation); + } + + if (!args.isNullAt(4)) { + int parallelism = args.getInt(4); + Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0"); + action = action.executeWith(executorService(parallelism, "table-snapshot")); + } + + SnapshotTable.Result result = action.tableProperties(properties).execute(); + return new InternalRow[] {newInternalRow(result.importedDataFilesCount())}; + } + + @Override + public String description() { + return "SnapshotTableProcedure"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/SparkProcedures.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/SparkProcedures.java new file mode 100644 index 000000000000..42003b24e94c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/procedures/SparkProcedures.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.Procedure; + +public class SparkProcedures { + + private static final Map> BUILDERS = initProcedureBuilders(); + + private SparkProcedures() {} + + public static ProcedureBuilder newBuilder(String name) { + // procedure resolution is case insensitive to match the existing Spark behavior for functions + Supplier builderSupplier = BUILDERS.get(name.toLowerCase(Locale.ROOT)); + return builderSupplier != null ? 
builderSupplier.get() : null; + } + + public static Set names() { + return BUILDERS.keySet(); + } + + private static Map> initProcedureBuilders() { + ImmutableMap.Builder> mapBuilder = ImmutableMap.builder(); + mapBuilder.put("rollback_to_snapshot", RollbackToSnapshotProcedure::builder); + mapBuilder.put("rollback_to_timestamp", RollbackToTimestampProcedure::builder); + mapBuilder.put("set_current_snapshot", SetCurrentSnapshotProcedure::builder); + mapBuilder.put("cherrypick_snapshot", CherrypickSnapshotProcedure::builder); + mapBuilder.put("rewrite_data_files", RewriteDataFilesProcedure::builder); + mapBuilder.put("rewrite_manifests", RewriteManifestsProcedure::builder); + mapBuilder.put("remove_orphan_files", RemoveOrphanFilesProcedure::builder); + mapBuilder.put("expire_snapshots", ExpireSnapshotsProcedure::builder); + mapBuilder.put("migrate", MigrateTableProcedure::builder); + mapBuilder.put("snapshot", SnapshotTableProcedure::builder); + mapBuilder.put("add_files", AddFilesProcedure::builder); + mapBuilder.put("ancestors_of", AncestorsOfProcedure::builder); + mapBuilder.put("register_table", RegisterTableProcedure::builder); + mapBuilder.put("publish_changes", PublishChangesProcedure::builder); + mapBuilder.put("create_changelog_view", CreateChangelogViewProcedure::builder); + mapBuilder.put("rewrite_position_delete_files", RewritePositionDeleteFilesProcedure::builder); + mapBuilder.put("fast_forward", FastForwardBranchProcedure::builder); + return mapBuilder.build(); + } + + public interface ProcedureBuilder { + ProcedureBuilder withTableCatalog(TableCatalog tableCatalog); + + Procedure build(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/BaseBatchReader.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/BaseBatchReader.java new file mode 100644 index 000000000000..49c43952135c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/BaseBatchReader.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
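// Editorial sketch (not part of this patch): one way the registry defined in SparkProcedures above
// might be exercised by a catalog implementation. `icebergTableCatalog` is a hypothetical
// TableCatalog instance; as noted in newBuilder(), the lookup is case-insensitive.
SparkProcedures.ProcedureBuilder builder = SparkProcedures.newBuilder("SNAPSHOT");
Procedure procedure = (builder != null) ? builder.withTableCatalog(icebergTableCatalog).build() : null;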
+ */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +abstract class BaseBatchReader extends BaseReader { + private final int batchSize; + + BaseBatchReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive, + int batchSize) { + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive); + this.batchSize = batchSize; + } + + protected CloseableIterable newBatchIterable( + InputFile inputFile, + FileFormat format, + long start, + long length, + Expression residual, + Map idToConstant, + SparkDeleteFilter deleteFilter) { + switch (format) { + case PARQUET: + return newParquetIterable(inputFile, start, length, residual, idToConstant, deleteFilter); + + case ORC: + return newOrcIterable(inputFile, start, length, residual, idToConstant); + + default: + throw new UnsupportedOperationException( + "Format: " + format + " not supported for batched reads"); + } + } + + private CloseableIterable newParquetIterable( + InputFile inputFile, + long start, + long length, + Expression residual, + Map idToConstant, + SparkDeleteFilter deleteFilter) { + // get required schema if there are deletes + Schema requiredSchema = deleteFilter != null ? deleteFilter.requiredSchema() : expectedSchema(); + boolean hasPositionDelete = deleteFilter != null ? deleteFilter.hasPosDeletes() : false; + Schema projectedSchema = requiredSchema; + if (hasPositionDelete) { + // We need to add MetadataColumns.ROW_POSITION in the schema for + // ReadConf.generateOffsetToStartPos(Schema schema). This is not needed any + // more after #10107 is merged. + List columns = Lists.newArrayList(requiredSchema.columns()); + if (!columns.contains(MetadataColumns.ROW_POSITION)) { + columns.add(MetadataColumns.ROW_POSITION); + projectedSchema = new Schema(columns); + } + } + + return Parquet.read(inputFile) + .project(projectedSchema) + .split(start, length) + .createBatchedReaderFunc( + fileSchema -> + VectorizedSparkParquetReaders.buildReader( + requiredSchema, fileSchema, idToConstant, deleteFilter)) + .recordsPerBatch(batchSize) + .filter(residual) + .caseSensitive(caseSensitive()) + // Spark eagerly consumes the batches. So the underlying memory allocated could be reused + // without worrying about subsequent reads clobbering over each other. This improves + // read performance as every batch read doesn't have to pay the cost of allocating memory. 
+ .reuseContainers() + .withNameMapping(nameMapping()) + .build(); + } + + private CloseableIterable newOrcIterable( + InputFile inputFile, + long start, + long length, + Expression residual, + Map idToConstant) { + Set constantFieldIds = idToConstant.keySet(); + Set metadataFieldIds = MetadataColumns.metadataFieldIds(); + Sets.SetView constantAndMetadataFieldIds = + Sets.union(constantFieldIds, metadataFieldIds); + Schema schemaWithoutConstantAndMetadataFields = + TypeUtil.selectNot(expectedSchema(), constantAndMetadataFieldIds); + + return ORC.read(inputFile) + .project(schemaWithoutConstantAndMetadataFields) + .split(start, length) + .createBatchedReaderFunc( + fileSchema -> + VectorizedSparkOrcReaders.buildReader(expectedSchema(), fileSchema, idToConstant)) + .recordsPerBatch(batchSize) + .filter(residual) + .caseSensitive(caseSensitive()) + .withNameMapping(nameMapping()) + .build(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java new file mode 100644 index 000000000000..f8e8a1f1dd6b --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.io.Closeable; +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.avro.generic.GenericData; +import org.apache.avro.util.Utf8; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.ContentScanTask; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.BaseDeleteLoader; +import org.apache.iceberg.data.DeleteFilter; +import org.apache.iceberg.data.DeleteLoader; +import org.apache.iceberg.deletes.DeleteCounter; +import org.apache.iceberg.encryption.EncryptingFileIO; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.spark.SparkExecutorCache; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.PartitionUtil; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Base class of Spark readers. + * + * @param is the Java class returned by this reader whose objects contain one or more rows. + */ +abstract class BaseReader implements Closeable { + private static final Logger LOG = LoggerFactory.getLogger(BaseReader.class); + + private final Table table; + private final Schema tableSchema; + private final Schema expectedSchema; + private final boolean caseSensitive; + private final NameMapping nameMapping; + private final ScanTaskGroup taskGroup; + private final Iterator tasks; + private final DeleteCounter counter; + + private Map lazyInputFiles; + private CloseableIterator currentIterator; + private T current = null; + private TaskT currentTask = null; + + BaseReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + this.table = table; + this.taskGroup = taskGroup; + this.tasks = taskGroup.tasks().iterator(); + this.currentIterator = CloseableIterator.empty(); + this.tableSchema = tableSchema; + this.expectedSchema = expectedSchema; + this.caseSensitive = caseSensitive; + String nameMappingString = table.properties().get(TableProperties.DEFAULT_NAME_MAPPING); + this.nameMapping = + nameMappingString != null ? 
NameMappingParser.fromJson(nameMappingString) : null; + this.counter = new DeleteCounter(); + } + + protected abstract CloseableIterator open(TaskT task); + + protected abstract Stream> referencedFiles(TaskT task); + + protected Schema expectedSchema() { + return expectedSchema; + } + + protected boolean caseSensitive() { + return caseSensitive; + } + + protected NameMapping nameMapping() { + return nameMapping; + } + + protected Table table() { + return table; + } + + protected DeleteCounter counter() { + return counter; + } + + public boolean next() throws IOException { + try { + while (true) { + if (currentIterator.hasNext()) { + this.current = currentIterator.next(); + return true; + } else if (tasks.hasNext()) { + this.currentIterator.close(); + this.currentTask = tasks.next(); + this.currentIterator = open(currentTask); + } else { + this.currentIterator.close(); + return false; + } + } + } catch (IOException | RuntimeException e) { + if (currentTask != null && !currentTask.isDataTask()) { + String filePaths = + referencedFiles(currentTask) + .map(ContentFile::location) + .collect(Collectors.joining(", ")); + LOG.error("Error reading file(s): {}", filePaths, e); + } + throw e; + } + } + + public T get() { + return current; + } + + @Override + public void close() throws IOException { + InputFileBlockHolder.unset(); + + // close the current iterator + this.currentIterator.close(); + + // exhaust the task iterator + while (tasks.hasNext()) { + tasks.next(); + } + } + + protected InputFile getInputFile(String location) { + return inputFiles().get(location); + } + + private Map inputFiles() { + if (lazyInputFiles == null) { + this.lazyInputFiles = + EncryptingFileIO.combine(table().io(), table().encryption()) + .bulkDecrypt( + () -> taskGroup.tasks().stream().flatMap(this::referencedFiles).iterator()); + } + + return lazyInputFiles; + } + + protected Map constantsMap(ContentScanTask task, Schema readSchema) { + if (readSchema.findField(MetadataColumns.PARTITION_COLUMN_ID) != null) { + StructType partitionType = Partitioning.partitionType(table); + return PartitionUtil.constantsMap(task, partitionType, BaseReader::convertConstant); + } else { + return PartitionUtil.constantsMap(task, BaseReader::convertConstant); + } + } + + protected static Object convertConstant(Type type, Object value) { + if (value == null) { + return null; + } + + switch (type.typeId()) { + case DECIMAL: + return Decimal.apply((BigDecimal) value); + case STRING: + if (value instanceof Utf8) { + Utf8 utf8 = (Utf8) value; + return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength()); + } + return UTF8String.fromString(value.toString()); + case FIXED: + if (value instanceof byte[]) { + return value; + } else if (value instanceof GenericData.Fixed) { + return ((GenericData.Fixed) value).bytes(); + } + return ByteBuffers.toByteArray((ByteBuffer) value); + case BINARY: + return ByteBuffers.toByteArray((ByteBuffer) value); + case STRUCT: + StructType structType = (StructType) type; + + if (structType.fields().isEmpty()) { + return new GenericInternalRow(); + } + + List fields = structType.fields(); + Object[] values = new Object[fields.size()]; + StructLike struct = (StructLike) value; + + for (int index = 0; index < fields.size(); index++) { + NestedField field = fields.get(index); + Type fieldType = field.type(); + values[index] = + convertConstant(fieldType, struct.get(index, fieldType.typeId().javaClass())); + } + + return new GenericInternalRow(values); + default: + } + return value; + } + + protected class 
SparkDeleteFilter extends DeleteFilter { + private final InternalRowWrapper asStructLike; + + SparkDeleteFilter( + String filePath, List deletes, DeleteCounter counter, boolean needRowPosCol) { + super(filePath, deletes, tableSchema, expectedSchema, counter, needRowPosCol); + this.asStructLike = + new InternalRowWrapper( + SparkSchemaUtil.convert(requiredSchema()), requiredSchema().asStruct()); + } + + @Override + protected StructLike asStructLike(InternalRow row) { + return asStructLike.wrap(row); + } + + @Override + protected InputFile getInputFile(String location) { + return BaseReader.this.getInputFile(location); + } + + @Override + protected void markRowDeleted(InternalRow row) { + if (!row.getBoolean(columnIsDeletedPosition())) { + row.setBoolean(columnIsDeletedPosition(), true); + counter().increment(); + } + } + + @Override + protected DeleteLoader newDeleteLoader() { + return new CachingDeleteLoader(this::loadInputFile); + } + + private class CachingDeleteLoader extends BaseDeleteLoader { + private final SparkExecutorCache cache; + + CachingDeleteLoader(Function loadInputFile) { + super(loadInputFile); + this.cache = SparkExecutorCache.getOrCreate(); + } + + @Override + protected boolean canCache(long size) { + return cache != null && size < cache.maxEntrySize(); + } + + @Override + protected V getOrLoad(String key, Supplier valueSupplier, long valueSize) { + return cache.getOrLoad(table().name(), key, valueSupplier, valueSize); + } + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java new file mode 100644 index 000000000000..eb97185e21f1 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.data.SparkOrcReader; +import org.apache.iceberg.spark.data.SparkParquetReaders; +import org.apache.iceberg.spark.data.SparkPlannedAvroReader; +import org.apache.iceberg.types.TypeUtil; +import org.apache.spark.sql.catalyst.InternalRow; + +abstract class BaseRowReader extends BaseReader { + BaseRowReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive); + } + + protected CloseableIterable newIterable( + InputFile file, + FileFormat format, + long start, + long length, + Expression residual, + Schema projection, + Map idToConstant) { + switch (format) { + case PARQUET: + return newParquetIterable(file, start, length, residual, projection, idToConstant); + + case AVRO: + return newAvroIterable(file, start, length, projection, idToConstant); + + case ORC: + return newOrcIterable(file, start, length, residual, projection, idToConstant); + + default: + throw new UnsupportedOperationException("Cannot read unknown format: " + format); + } + } + + private CloseableIterable newAvroIterable( + InputFile file, long start, long length, Schema projection, Map idToConstant) { + return Avro.read(file) + .reuseContainers() + .project(projection) + .split(start, length) + .createReaderFunc(readSchema -> SparkPlannedAvroReader.create(projection, idToConstant)) + .withNameMapping(nameMapping()) + .build(); + } + + private CloseableIterable newParquetIterable( + InputFile file, + long start, + long length, + Expression residual, + Schema readSchema, + Map idToConstant) { + return Parquet.read(file) + .reuseContainers() + .split(start, length) + .project(readSchema) + .createReaderFunc( + fileSchema -> SparkParquetReaders.buildReader(readSchema, fileSchema, idToConstant)) + .filter(residual) + .caseSensitive(caseSensitive()) + .withNameMapping(nameMapping()) + .build(); + } + + private CloseableIterable newOrcIterable( + InputFile file, + long start, + long length, + Expression residual, + Schema readSchema, + Map idToConstant) { + Schema readSchemaWithoutConstantAndMetadataFields = + TypeUtil.selectNot( + readSchema, Sets.union(idToConstant.keySet(), MetadataColumns.metadataFieldIds())); + + return ORC.read(file) + .project(readSchemaWithoutConstantAndMetadataFields) + .split(start, length) + .createReaderFunc( + readOrcSchema -> new SparkOrcReader(readSchema, readOrcSchema, idToConstant)) + .filter(residual) + .caseSensitive(caseSensitive()) + .withNameMapping(nameMapping()) + .build(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java new file mode 100644 index 000000000000..f45c152203ee --- /dev/null +++ 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import java.util.stream.Stream; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.source.metrics.TaskNumDeletes; +import org.apache.iceberg.spark.source.metrics.TaskNumSplits; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class BatchDataReader extends BaseBatchReader + implements PartitionReader { + + private static final Logger LOG = LoggerFactory.getLogger(BatchDataReader.class); + + private final long numSplits; + + BatchDataReader(SparkInputPartition partition, int batchSize) { + this( + partition.table(), + partition.taskGroup(), + SnapshotUtil.schemaFor(partition.table(), partition.branch()), + partition.expectedSchema(), + partition.isCaseSensitive(), + batchSize); + } + + BatchDataReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive, + int size) { + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive, size); + + numSplits = taskGroup.tasks().size(); + LOG.debug("Reading {} file split(s) for table {}", numSplits, table.name()); + } + + @Override + public CustomTaskMetric[] currentMetricsValues() { + return new CustomTaskMetric[] { + new TaskNumSplits(numSplits), new TaskNumDeletes(counter().get()) + }; + } + + @Override + protected Stream> referencedFiles(FileScanTask task) { + return Stream.concat(Stream.of(task.file()), task.deletes().stream()); + } + + @Override + protected CloseableIterator open(FileScanTask task) { + String filePath = task.file().location(); + LOG.debug("Opening data file {}", filePath); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(filePath, task.start(), task.length()); + + Map idToConstant = constantsMap(task, expectedSchema()); + + InputFile inputFile = getInputFile(filePath); + Preconditions.checkNotNull(inputFile, "Could not find InputFile associated with FileScanTask"); + + SparkDeleteFilter deleteFilter = + task.deletes().isEmpty() + ? 
null + : new SparkDeleteFilter(filePath, task.deletes(), counter(), false); + + return newBatchIterable( + inputFile, + task.file().format(), + task.start(), + task.length(), + task.residual(), + idToConstant, + deleteFilter) + .iterator(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/ChangelogRowReader.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/ChangelogRowReader.java new file mode 100644 index 000000000000..c8e6f5679cd8 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/ChangelogRowReader.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import org.apache.iceberg.AddedRowsScanTask; +import org.apache.iceberg.ChangelogScanTask; +import org.apache.iceberg.ChangelogUtil; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.ContentScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.DeletedDataFileScanTask; +import org.apache.iceberg.DeletedRowsScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.expressions.JoinedRow; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.unsafe.types.UTF8String; + +class ChangelogRowReader extends BaseRowReader + implements PartitionReader { + + ChangelogRowReader(SparkInputPartition partition) { + this( + partition.table(), + partition.taskGroup(), + SnapshotUtil.schemaFor(partition.table(), partition.branch()), + partition.expectedSchema(), + partition.isCaseSensitive()); + } + + ChangelogRowReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + super( + table, + taskGroup, + tableSchema, + ChangelogUtil.dropChangelogMetadata(expectedSchema), + caseSensitive); + } + + @Override + protected CloseableIterator open(ChangelogScanTask task) { + JoinedRow cdcRow = new JoinedRow(); + + cdcRow.withRight(changelogMetadata(task)); + + CloseableIterable rows = openChangelogScanTask(task); + CloseableIterable cdcRows = CloseableIterable.transform(rows, cdcRow::withLeft); 
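// Editorial note (not from the original patch): cdcRow joins each data row (left side) with the
// metadata row built in changelogMetadata() below (right side), so the operation name, change
// ordinal, and commit snapshot id are appended after the data columns of every emitted row.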
+ + return cdcRows.iterator(); + } + + private static InternalRow changelogMetadata(ChangelogScanTask task) { + InternalRow metadataRow = new GenericInternalRow(3); + + metadataRow.update(0, UTF8String.fromString(task.operation().name())); + metadataRow.update(1, task.changeOrdinal()); + metadataRow.update(2, task.commitSnapshotId()); + + return metadataRow; + } + + private CloseableIterable openChangelogScanTask(ChangelogScanTask task) { + if (task instanceof AddedRowsScanTask) { + return openAddedRowsScanTask((AddedRowsScanTask) task); + + } else if (task instanceof DeletedRowsScanTask) { + throw new UnsupportedOperationException("Deleted rows scan task is not supported yet"); + + } else if (task instanceof DeletedDataFileScanTask) { + return openDeletedDataFileScanTask((DeletedDataFileScanTask) task); + + } else { + throw new IllegalArgumentException( + "Unsupported changelog scan task type: " + task.getClass().getName()); + } + } + + CloseableIterable openAddedRowsScanTask(AddedRowsScanTask task) { + String filePath = task.file().location(); + SparkDeleteFilter deletes = new SparkDeleteFilter(filePath, task.deletes(), counter(), true); + return deletes.filter(rows(task, deletes.requiredSchema())); + } + + private CloseableIterable openDeletedDataFileScanTask(DeletedDataFileScanTask task) { + String filePath = task.file().path().toString(); + SparkDeleteFilter deletes = + new SparkDeleteFilter(filePath, task.existingDeletes(), counter(), true); + return deletes.filter(rows(task, deletes.requiredSchema())); + } + + private CloseableIterable rows(ContentScanTask task, Schema readSchema) { + Map idToConstant = constantsMap(task, readSchema); + + String filePath = task.file().path().toString(); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(filePath, task.start(), task.length()); + + InputFile location = getInputFile(filePath); + Preconditions.checkNotNull(location, "Could not find InputFile"); + return newIterable( + location, + task.file().format(), + task.start(), + task.length(), + task.residual(), + readSchema, + idToConstant); + } + + @Override + protected Stream> referencedFiles(ChangelogScanTask task) { + if (task instanceof AddedRowsScanTask) { + return addedRowsScanTaskFiles((AddedRowsScanTask) task); + + } else if (task instanceof DeletedRowsScanTask) { + throw new UnsupportedOperationException("Deleted rows scan task is not supported yet"); + + } else if (task instanceof DeletedDataFileScanTask) { + return deletedDataFileScanTaskFiles((DeletedDataFileScanTask) task); + + } else { + throw new IllegalArgumentException( + "Unsupported changelog scan task type: " + task.getClass().getName()); + } + } + + private static Stream> deletedDataFileScanTaskFiles(DeletedDataFileScanTask task) { + DataFile file = task.file(); + List existingDeletes = task.existingDeletes(); + return Stream.concat(Stream.of(file), existingDeletes.stream()); + } + + private static Stream> addedRowsScanTaskFiles(AddedRowsScanTask task) { + DataFile file = task.file(); + List deletes = task.deletes(); + return Stream.concat(Stream.of(file), deletes.stream()); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java new file mode 100644 index 000000000000..ee9449ee13c8 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java @@ -0,0 +1,56 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.InternalRow; + +public class EqualityDeleteRowReader extends RowDataReader { + public EqualityDeleteRowReader( + CombinedScanTask task, + Table table, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + super(table, task, tableSchema, expectedSchema, caseSensitive); + } + + @Override + protected CloseableIterator open(FileScanTask task) { + SparkDeleteFilter matches = + new SparkDeleteFilter(task.file().location(), task.deletes(), counter(), true); + + // schema or rows returned by readers + Schema requiredSchema = matches.requiredSchema(); + Map idToConstant = constantsMap(task, expectedSchema()); + DataFile file = task.file(); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(file.location(), task.start(), task.length()); + + return matches.findEqualityDeleteRows(open(task, requiredSchema, idToConstant)).iterator(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/HasIcebergCatalog.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/HasIcebergCatalog.java new file mode 100644 index 000000000000..37e0c4dfcdb6 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/HasIcebergCatalog.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.catalog.Catalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; + +public interface HasIcebergCatalog extends TableCatalog { + + /** + * Returns the underlying {@link org.apache.iceberg.catalog.Catalog} backing this Spark Catalog + */ + Catalog icebergCatalog(); +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/IcebergSource.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/IcebergSource.java new file mode 100644 index 000000000000..e6edda85b499 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/IcebergSource.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Stream; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.PathIdentifier; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkCachedTableCatalog; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.SparkTableCache; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.CatalogManager; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.SupportsCatalogOptions; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.sources.DataSourceRegister; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * The IcebergSource loads/writes tables with format "iceberg". It can load paths and tables. + * + *

<p>How paths/tables are loaded when using spark.read().format("iceberg").load(table) + * + *
<p>table = "file:///path/to/table" -> loads a HadoopTable at given path table = "tablename" + * -> loads currentCatalog.currentNamespace.tablename table = "catalog.tablename" -> load + * "tablename" from the specified catalog. table = "namespace.tablename" -> load + * "namespace.tablename" from current catalog table = "catalog.namespace.tablename" -> + * "namespace.tablename" from the specified catalog. table = "namespace1.namespace2.tablename" -> + * load "namespace1.namespace2.tablename" from current catalog + * + *
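* <p>Illustrative calls (editorial sketch, not part of the original patch; the catalog, namespace,
* and path names below are hypothetical):
*   spark.read().format("iceberg").load("file:///tmp/warehouse/db/events")  // HadoopTable at a path
*   spark.read().format("iceberg").load("prod_catalog.db.events")           // table from the named catalog
*   spark.read().format("iceberg").load("db.events")                        // table from the current catalog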

The above list is in order of priority. For example: a matching catalog will take priority + * over any namespace resolution. + */ +public class IcebergSource implements DataSourceRegister, SupportsCatalogOptions { + private static final String DEFAULT_CATALOG_NAME = "default_iceberg"; + private static final String DEFAULT_CACHE_CATALOG_NAME = "default_cache_iceberg"; + private static final String DEFAULT_CATALOG = "spark.sql.catalog." + DEFAULT_CATALOG_NAME; + private static final String DEFAULT_CACHE_CATALOG = + "spark.sql.catalog." + DEFAULT_CACHE_CATALOG_NAME; + private static final String AT_TIMESTAMP = "at_timestamp_"; + private static final String SNAPSHOT_ID = "snapshot_id_"; + private static final String BRANCH_PREFIX = "branch_"; + private static final String TAG_PREFIX = "tag_"; + private static final String[] EMPTY_NAMESPACE = new String[0]; + + private static final SparkTableCache TABLE_CACHE = SparkTableCache.get(); + + @Override + public String shortName() { + return "iceberg"; + } + + @Override + public StructType inferSchema(CaseInsensitiveStringMap options) { + return null; + } + + @Override + public Transform[] inferPartitioning(CaseInsensitiveStringMap options) { + return getTable(null, null, options).partitioning(); + } + + @Override + public boolean supportsExternalMetadata() { + return true; + } + + @Override + public Table getTable(StructType schema, Transform[] partitioning, Map options) { + Spark3Util.CatalogAndIdentifier catalogIdentifier = + catalogAndIdentifier(new CaseInsensitiveStringMap(options)); + CatalogPlugin catalog = catalogIdentifier.catalog(); + Identifier ident = catalogIdentifier.identifier(); + + try { + if (catalog instanceof TableCatalog) { + return ((TableCatalog) catalog).loadTable(ident); + } + } catch (NoSuchTableException e) { + // throwing an iceberg NoSuchTableException because the Spark one is typed and cant be thrown + // from this interface + throw new org.apache.iceberg.exceptions.NoSuchTableException( + e, "Cannot find table for %s.", ident); + } + + // throwing an iceberg NoSuchTableException because the Spark one is typed and cant be thrown + // from this interface + throw new org.apache.iceberg.exceptions.NoSuchTableException( + "Cannot find table for %s.", ident); + } + + private Spark3Util.CatalogAndIdentifier catalogAndIdentifier(CaseInsensitiveStringMap options) { + Preconditions.checkArgument( + options.containsKey(SparkReadOptions.PATH), "Cannot open table: path is not set"); + SparkSession spark = SparkSession.active(); + setupDefaultSparkCatalogs(spark); + String path = options.get(SparkReadOptions.PATH); + + Long snapshotId = propertyAsLong(options, SparkReadOptions.SNAPSHOT_ID); + Long asOfTimestamp = propertyAsLong(options, SparkReadOptions.AS_OF_TIMESTAMP); + String branch = options.get(SparkReadOptions.BRANCH); + String tag = options.get(SparkReadOptions.TAG); + Preconditions.checkArgument( + Stream.of(snapshotId, asOfTimestamp, branch, tag).filter(Objects::nonNull).count() <= 1, + "Can specify only one of snapshot-id (%s), as-of-timestamp (%s), branch (%s), tag (%s)", + snapshotId, + asOfTimestamp, + branch, + tag); + + String selector = null; + + if (snapshotId != null) { + selector = SNAPSHOT_ID + snapshotId; + } + + if (asOfTimestamp != null) { + selector = AT_TIMESTAMP + asOfTimestamp; + } + + if (branch != null) { + selector = BRANCH_PREFIX + branch; + } + + if (tag != null) { + selector = TAG_PREFIX + tag; + } + + CatalogManager catalogManager = spark.sessionState().catalogManager(); + + if 
(TABLE_CACHE.contains(path)) { + return new Spark3Util.CatalogAndIdentifier( + catalogManager.catalog(DEFAULT_CACHE_CATALOG_NAME), + Identifier.of(EMPTY_NAMESPACE, pathWithSelector(path, selector))); + } else if (path.contains("/")) { + // contains a path. Return iceberg default catalog and a PathIdentifier + return new Spark3Util.CatalogAndIdentifier( + catalogManager.catalog(DEFAULT_CATALOG_NAME), + new PathIdentifier(pathWithSelector(path, selector))); + } + + final Spark3Util.CatalogAndIdentifier catalogAndIdentifier = + Spark3Util.catalogAndIdentifier("path or identifier", spark, path); + + Identifier ident = identifierWithSelector(catalogAndIdentifier.identifier(), selector); + if (catalogAndIdentifier.catalog().name().equals("spark_catalog") + && !(catalogAndIdentifier.catalog() instanceof SparkSessionCatalog)) { + // catalog is a session catalog but does not support Iceberg. Use Iceberg instead. + return new Spark3Util.CatalogAndIdentifier( + catalogManager.catalog(DEFAULT_CATALOG_NAME), ident); + } else { + return new Spark3Util.CatalogAndIdentifier(catalogAndIdentifier.catalog(), ident); + } + } + + private String pathWithSelector(String path, String selector) { + return (selector == null) ? path : path + "#" + selector; + } + + private Identifier identifierWithSelector(Identifier ident, String selector) { + if (selector == null) { + return ident; + } else { + String[] namespace = ident.namespace(); + String[] ns = Arrays.copyOf(namespace, namespace.length + 1); + ns[namespace.length] = ident.name(); + return Identifier.of(ns, selector); + } + } + + @Override + public Identifier extractIdentifier(CaseInsensitiveStringMap options) { + return catalogAndIdentifier(options).identifier(); + } + + @Override + public String extractCatalog(CaseInsensitiveStringMap options) { + return catalogAndIdentifier(options).catalog().name(); + } + + @Override + public Optional extractTimeTravelVersion(CaseInsensitiveStringMap options) { + return Optional.ofNullable( + PropertyUtil.propertyAsString(options, SparkReadOptions.VERSION_AS_OF, null)); + } + + @Override + public Optional extractTimeTravelTimestamp(CaseInsensitiveStringMap options) { + return Optional.ofNullable( + PropertyUtil.propertyAsString(options, SparkReadOptions.TIMESTAMP_AS_OF, null)); + } + + private static Long propertyAsLong(CaseInsensitiveStringMap options, String property) { + String value = options.get(property); + if (value != null) { + return Long.parseLong(value); + } + + return null; + } + + private static void setupDefaultSparkCatalogs(SparkSession spark) { + if (spark.conf().getOption(DEFAULT_CATALOG).isEmpty()) { + ImmutableMap config = + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "cache-enabled", "false" // the source should not use a cache + ); + spark.conf().set(DEFAULT_CATALOG, SparkCatalog.class.getName()); + config.forEach((key, value) -> spark.conf().set(DEFAULT_CATALOG + "." 
+ key, value)); + } + + if (spark.conf().getOption(DEFAULT_CACHE_CATALOG).isEmpty()) { + spark.conf().set(DEFAULT_CACHE_CATALOG, SparkCachedTableCatalog.class.getName()); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java new file mode 100644 index 000000000000..d1682b8c85c1 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.nio.ByteBuffer; +import java.util.UUID; +import java.util.function.BiFunction; +import java.util.stream.Stream; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * Class to adapt a Spark {@code InternalRow} to Iceberg {@link StructLike} for uses like {@link + * org.apache.iceberg.PartitionKey#partition(StructLike)} + */ +class InternalRowWrapper implements StructLike { + private final DataType[] types; + private final BiFunction[] getters; + private InternalRow row = null; + + @SuppressWarnings("unchecked") + InternalRowWrapper(StructType rowType, Types.StructType icebergSchema) { + this.types = Stream.of(rowType.fields()).map(StructField::dataType).toArray(DataType[]::new); + Preconditions.checkArgument( + types.length == icebergSchema.fields().size(), + "Invalid length: Spark struct type (%s) != Iceberg struct type (%s)", + types.length, + icebergSchema.fields().size()); + this.getters = new BiFunction[types.length]; + for (int i = 0; i < types.length; i++) { + getters[i] = getter(icebergSchema.fields().get(i).type(), types[i]); + } + } + + InternalRowWrapper wrap(InternalRow internalRow) { + this.row = internalRow; + return this; + } + + @Override + public int size() { + return types.length; + } + + @Override + public T get(int pos, Class javaClass) { + if (row.isNullAt(pos)) { + return null; + } else if (getters[pos] != null) { + return javaClass.cast(getters[pos].apply(row, pos)); + } + + return javaClass.cast(row.get(pos, types[pos])); + } + + @Override + public void set(int pos, T value) { + row.update(pos, value); + } + + private static BiFunction getter(Type icebergType, DataType type) { + if (type 
instanceof StringType) { + // Spark represents UUIDs as strings + if (Type.TypeID.UUID == icebergType.typeId()) { + return (row, pos) -> UUID.fromString(row.getUTF8String(pos).toString()); + } + + return (row, pos) -> row.getUTF8String(pos).toString(); + } else if (type instanceof DecimalType) { + DecimalType decimal = (DecimalType) type; + return (row, pos) -> + row.getDecimal(pos, decimal.precision(), decimal.scale()).toJavaBigDecimal(); + } else if (type instanceof BinaryType) { + return (row, pos) -> ByteBuffer.wrap(row.getBinary(pos)); + } else if (type instanceof StructType) { + StructType structType = (StructType) type; + InternalRowWrapper nestedWrapper = + new InternalRowWrapper(structType, icebergType.asStructType()); + return (row, pos) -> nestedWrapper.wrap(row.getStruct(pos, structType.size())); + } + + return null; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java new file mode 100644 index 000000000000..1a894df29166 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
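// Editorial sketch (not part of this patch): how InternalRowWrapper above can adapt a Spark
// InternalRow to an Iceberg StructLike, e.g. for partition key evaluation as mentioned in its
// Javadoc. `sparkType`, `schema`, `spec`, and `row` are hypothetical variables assumed in scope.
InternalRowWrapper wrapper = new InternalRowWrapper(sparkType, schema.asStruct());
PartitionKey partitionKey = new PartitionKey(spec, schema);
partitionKey.partition(wrapper.wrap(row));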
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.PositionDeletesScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.ExpressionUtil; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.primitives.Ints; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class PositionDeletesRowReader extends BaseRowReader + implements PartitionReader { + + private static final Logger LOG = LoggerFactory.getLogger(PositionDeletesRowReader.class); + + PositionDeletesRowReader(SparkInputPartition partition) { + this( + partition.table(), + partition.taskGroup(), + SnapshotUtil.schemaFor(partition.table(), partition.branch()), + partition.expectedSchema(), + partition.isCaseSensitive()); + } + + PositionDeletesRowReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive); + + int numSplits = taskGroup.tasks().size(); + LOG.debug("Reading {} position delete file split(s) for table {}", numSplits, table.name()); + } + + @Override + protected Stream> referencedFiles(PositionDeletesScanTask task) { + return Stream.of(task.file()); + } + + @Override + protected CloseableIterator open(PositionDeletesScanTask task) { + String filePath = task.file().location(); + LOG.debug("Opening position delete file {}", filePath); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(filePath, task.start(), task.length()); + + InputFile inputFile = getInputFile(task.file().location()); + Preconditions.checkNotNull(inputFile, "Could not find InputFile associated with %s", task); + + // select out constant fields when pushing down filter to row reader + Map idToConstant = constantsMap(task, expectedSchema()); + Set nonConstantFieldIds = nonConstantFieldIds(idToConstant); + Expression residualWithoutConstants = + ExpressionUtil.extractByIdInclusive( + task.residual(), expectedSchema(), caseSensitive(), Ints.toArray(nonConstantFieldIds)); + + return newIterable( + inputFile, + task.file().format(), + task.start(), + task.length(), + residualWithoutConstants, + expectedSchema(), + idToConstant) + .iterator(); + } + + private Set nonConstantFieldIds(Map idToConstant) { + Set fields = expectedSchema().idToName().keySet(); + return fields.stream() + .filter(id -> expectedSchema().findField(id).type().isPrimitiveType()) + .filter(id -> !idToConstant.containsKey(id)) + .collect(Collectors.toSet()); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java new file mode 100644 index 000000000000..f24602fd5583 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java @@ -0,0 +1,122 @@ +/* + 
* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import java.util.stream.Stream; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.source.metrics.TaskNumDeletes; +import org.apache.iceberg.spark.source.metrics.TaskNumSplits; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class RowDataReader extends BaseRowReader implements PartitionReader { + private static final Logger LOG = LoggerFactory.getLogger(RowDataReader.class); + + private final long numSplits; + + RowDataReader(SparkInputPartition partition) { + this( + partition.table(), + partition.taskGroup(), + SnapshotUtil.schemaFor(partition.table(), partition.branch()), + partition.expectedSchema(), + partition.isCaseSensitive()); + } + + RowDataReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive); + + numSplits = taskGroup.tasks().size(); + LOG.debug("Reading {} file split(s) for table {}", numSplits, table.name()); + } + + @Override + public CustomTaskMetric[] currentMetricsValues() { + return new CustomTaskMetric[] { + new TaskNumSplits(numSplits), new TaskNumDeletes(counter().get()) + }; + } + + @Override + protected Stream> referencedFiles(FileScanTask task) { + return Stream.concat(Stream.of(task.file()), task.deletes().stream()); + } + + @Override + protected CloseableIterator open(FileScanTask task) { + String filePath = task.file().location(); + LOG.debug("Opening data file {}", filePath); + SparkDeleteFilter deleteFilter = + new SparkDeleteFilter(filePath, task.deletes(), counter(), true); + + // schema or rows returned by readers + Schema requiredSchema = deleteFilter.requiredSchema(); + Map idToConstant = constantsMap(task, requiredSchema); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(filePath, task.start(), task.length()); + + return deleteFilter.filter(open(task, requiredSchema, idToConstant)).iterator(); + } + + protected 
CloseableIterable open( + FileScanTask task, Schema readSchema, Map idToConstant) { + if (task.isDataTask()) { + return newDataIterable(task.asDataTask(), readSchema); + } else { + InputFile inputFile = getInputFile(task.file().location()); + Preconditions.checkNotNull( + inputFile, "Could not find InputFile associated with FileScanTask"); + return newIterable( + inputFile, + task.file().format(), + task.start(), + task.length(), + task.residual(), + readSchema, + idToConstant); + } + } + + private CloseableIterable newDataIterable(DataTask task, Schema readSchema) { + StructInternalRow row = new StructInternalRow(readSchema.asStruct()); + return CloseableIterable.transform(task.asDataTask().rows(), row::setStruct); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SerializableTableWithSize.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SerializableTableWithSize.java new file mode 100644 index 000000000000..f6913fb9d00d --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SerializableTableWithSize.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.BaseMetadataTable; +import org.apache.iceberg.SerializableTable; +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.SparkExecutorCache; +import org.apache.spark.util.KnownSizeEstimation; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class provides a serializable table with a known size estimate. Spark calls its + * SizeEstimator class when broadcasting variables and this can be an expensive operation, so + * providing a known size estimate allows that operation to be skipped. + * + *
<p>
This class also implements AutoCloseable to avoid leaking resources upon broadcasting. + * Broadcast variables are destroyed and cleaned up on the driver and executors once they are + * garbage collected on the driver. The implementation ensures only resources used by copies of the + * main table are released. + */ +public class SerializableTableWithSize extends SerializableTable + implements KnownSizeEstimation, AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(SerializableTableWithSize.class); + private static final long SIZE_ESTIMATE = 32_768L; + + private final transient Object serializationMarker; + + protected SerializableTableWithSize(Table table) { + super(table); + this.serializationMarker = new Object(); + } + + @Override + public long estimatedSize() { + return SIZE_ESTIMATE; + } + + public static Table copyOf(Table table) { + if (table instanceof BaseMetadataTable) { + return new SerializableMetadataTableWithSize((BaseMetadataTable) table); + } else { + return new SerializableTableWithSize(table); + } + } + + @Override + public void close() throws Exception { + if (serializationMarker == null) { + LOG.info("Releasing resources"); + io().close(); + } + invalidateCache(name()); + } + + public static class SerializableMetadataTableWithSize extends SerializableMetadataTable + implements KnownSizeEstimation, AutoCloseable { + + private static final Logger LOG = + LoggerFactory.getLogger(SerializableMetadataTableWithSize.class); + + private final transient Object serializationMarker; + + protected SerializableMetadataTableWithSize(BaseMetadataTable metadataTable) { + super(metadataTable); + this.serializationMarker = new Object(); + } + + @Override + public long estimatedSize() { + return SIZE_ESTIMATE; + } + + @Override + public void close() throws Exception { + if (serializationMarker == null) { + LOG.info("Releasing resources"); + io().close(); + } + invalidateCache(name()); + } + } + + private static void invalidateCache(String name) { + SparkExecutorCache cache = SparkExecutorCache.get(); + if (cache != null) { + cache.invalidate(name); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkAppenderFactory.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkAppenderFactory.java new file mode 100644 index 000000000000..c822ed743f85 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkAppenderFactory.java @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.deletes.EqualityDeleteWriter; +import org.apache.iceberg.deletes.PositionDeleteWriter; +import org.apache.iceberg.encryption.EncryptedOutputFile; +import org.apache.iceberg.encryption.EncryptionUtil; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.io.DataWriter; +import org.apache.iceberg.io.DeleteSchemaUtil; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.SparkAvroWriter; +import org.apache.iceberg.spark.data.SparkOrcWriter; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * @deprecated since 1.7.0, will be removed in 1.8.0; use {@link SparkFileWriterFactory} instead. + */ +@Deprecated +class SparkAppenderFactory implements FileAppenderFactory { + private final Map properties; + private final Schema writeSchema; + private final StructType dsSchema; + private final PartitionSpec spec; + private final int[] equalityFieldIds; + private final Schema eqDeleteRowSchema; + private final Schema posDeleteRowSchema; + + private StructType eqDeleteSparkType = null; + private StructType posDeleteSparkType = null; + + SparkAppenderFactory( + Map properties, + Schema writeSchema, + StructType dsSchema, + PartitionSpec spec, + int[] equalityFieldIds, + Schema eqDeleteRowSchema, + Schema posDeleteRowSchema) { + this.properties = properties; + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + this.spec = spec; + this.equalityFieldIds = equalityFieldIds; + this.eqDeleteRowSchema = eqDeleteRowSchema; + this.posDeleteRowSchema = posDeleteRowSchema; + } + + static Builder builderFor(Table table, Schema writeSchema, StructType dsSchema) { + return new Builder(table, writeSchema, dsSchema); + } + + static class Builder { + private final Table table; + private final Schema writeSchema; + private final StructType dsSchema; + private PartitionSpec spec; + private int[] equalityFieldIds; + private Schema eqDeleteRowSchema; + private Schema posDeleteRowSchema; + + Builder(Table table, Schema writeSchema, StructType dsSchema) { + this.table = table; + this.spec = table.spec(); + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + } + + Builder spec(PartitionSpec newSpec) { + this.spec = newSpec; + return this; + } + + Builder equalityFieldIds(int[] newEqualityFieldIds) { + this.equalityFieldIds = newEqualityFieldIds; + return this; + } + + Builder eqDeleteRowSchema(Schema newEqDeleteRowSchema) { + this.eqDeleteRowSchema = newEqDeleteRowSchema; + return this; + } + + Builder posDelRowSchema(Schema newPosDelRowSchema) { + this.posDeleteRowSchema = newPosDelRowSchema; + return this; + } + + SparkAppenderFactory build() { + 
Preconditions.checkNotNull(table, "Table must not be null"); + Preconditions.checkNotNull(writeSchema, "Write Schema must not be null"); + Preconditions.checkNotNull(dsSchema, "DS Schema must not be null"); + if (equalityFieldIds != null) { + Preconditions.checkNotNull( + eqDeleteRowSchema, + "Equality Field Ids and Equality Delete Row Schema" + " must be set together"); + } + if (eqDeleteRowSchema != null) { + Preconditions.checkNotNull( + equalityFieldIds, + "Equality Field Ids and Equality Delete Row Schema" + " must be set together"); + } + + return new SparkAppenderFactory( + table.properties(), + writeSchema, + dsSchema, + spec, + equalityFieldIds, + eqDeleteRowSchema, + posDeleteRowSchema); + } + } + + private StructType lazyEqDeleteSparkType() { + if (eqDeleteSparkType == null) { + Preconditions.checkNotNull(eqDeleteRowSchema, "Equality delete row schema shouldn't be null"); + this.eqDeleteSparkType = SparkSchemaUtil.convert(eqDeleteRowSchema); + } + return eqDeleteSparkType; + } + + private StructType lazyPosDeleteSparkType() { + if (posDeleteSparkType == null) { + Preconditions.checkNotNull( + posDeleteRowSchema, "Position delete row schema shouldn't be null"); + this.posDeleteSparkType = SparkSchemaUtil.convert(posDeleteRowSchema); + } + return posDeleteSparkType; + } + + @Override + public FileAppender newAppender(OutputFile file, FileFormat fileFormat) { + return newAppender(EncryptionUtil.plainAsEncryptedOutput(file), fileFormat); + } + + @Override + public FileAppender newAppender(EncryptedOutputFile file, FileFormat fileFormat) { + MetricsConfig metricsConfig = MetricsConfig.fromProperties(properties); + try { + switch (fileFormat) { + case PARQUET: + return Parquet.write(file) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(dsSchema, msgType)) + .setAll(properties) + .metricsConfig(metricsConfig) + .schema(writeSchema) + .overwrite() + .build(); + + case AVRO: + return Avro.write(file) + .createWriterFunc(ignored -> new SparkAvroWriter(dsSchema)) + .setAll(properties) + .schema(writeSchema) + .overwrite() + .build(); + + case ORC: + return ORC.write(file) + .createWriterFunc(SparkOrcWriter::new) + .setAll(properties) + .metricsConfig(metricsConfig) + .schema(writeSchema) + .overwrite() + .build(); + + default: + throw new UnsupportedOperationException("Cannot write unknown format: " + fileFormat); + } + } catch (IOException e) { + throw new RuntimeIOException(e); + } + } + + @Override + public DataWriter newDataWriter( + EncryptedOutputFile file, FileFormat format, StructLike partition) { + return new DataWriter<>( + newAppender(file, format), + format, + file.encryptingOutputFile().location(), + spec, + partition, + file.keyMetadata()); + } + + @Override + public EqualityDeleteWriter newEqDeleteWriter( + EncryptedOutputFile file, FileFormat format, StructLike partition) { + Preconditions.checkState( + equalityFieldIds != null && equalityFieldIds.length > 0, + "Equality field ids shouldn't be null or empty when creating equality-delete writer"); + Preconditions.checkNotNull( + eqDeleteRowSchema, + "Equality delete row schema shouldn't be null when creating equality-delete writer"); + + try { + switch (format) { + case PARQUET: + return Parquet.writeDeletes(file) + .createWriterFunc( + msgType -> SparkParquetWriters.buildWriter(lazyEqDeleteSparkType(), msgType)) + .overwrite() + .rowSchema(eqDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .equalityFieldIds(equalityFieldIds) + .withKeyMetadata(file.keyMetadata()) + .buildEqualityWriter(); + + 
case AVRO: + return Avro.writeDeletes(file) + .createWriterFunc(ignored -> new SparkAvroWriter(lazyEqDeleteSparkType())) + .overwrite() + .rowSchema(eqDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .equalityFieldIds(equalityFieldIds) + .withKeyMetadata(file.keyMetadata()) + .buildEqualityWriter(); + + case ORC: + return ORC.writeDeletes(file) + .createWriterFunc(SparkOrcWriter::new) + .overwrite() + .rowSchema(eqDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .equalityFieldIds(equalityFieldIds) + .withKeyMetadata(file.keyMetadata()) + .buildEqualityWriter(); + + default: + throw new UnsupportedOperationException( + "Cannot write equality-deletes for unsupported file format: " + format); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to create new equality delete writer", e); + } + } + + @Override + public PositionDeleteWriter newPosDeleteWriter( + EncryptedOutputFile file, FileFormat format, StructLike partition) { + try { + switch (format) { + case PARQUET: + StructType sparkPosDeleteSchema = + SparkSchemaUtil.convert(DeleteSchemaUtil.posDeleteSchema(posDeleteRowSchema)); + return Parquet.writeDeletes(file) + .createWriterFunc( + msgType -> SparkParquetWriters.buildWriter(sparkPosDeleteSchema, msgType)) + .overwrite() + .rowSchema(posDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .withKeyMetadata(file.keyMetadata()) + .transformPaths(path -> UTF8String.fromString(path.toString())) + .buildPositionWriter(); + + case AVRO: + return Avro.writeDeletes(file) + .createWriterFunc(ignored -> new SparkAvroWriter(lazyPosDeleteSparkType())) + .overwrite() + .rowSchema(posDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .withKeyMetadata(file.keyMetadata()) + .buildPositionWriter(); + + case ORC: + return ORC.writeDeletes(file) + .createWriterFunc(SparkOrcWriter::new) + .overwrite() + .rowSchema(posDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .withKeyMetadata(file.keyMetadata()) + .transformPaths(path -> UTF8String.fromString(path.toString())) + .buildPositionWriter(); + + default: + throw new UnsupportedOperationException( + "Cannot write pos-deletes for unsupported file format: " + format); + } + + } catch (IOException e) { + throw new UncheckedIOException("Failed to create new equality delete writer", e); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java new file mode 100644 index 000000000000..fd6783f3e1f7 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Objects; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SchemaParser; +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; + +class SparkBatch implements Batch { + + private final JavaSparkContext sparkContext; + private final Table table; + private final String branch; + private final SparkReadConf readConf; + private final Types.StructType groupingKeyType; + private final List> taskGroups; + private final Schema expectedSchema; + private final boolean caseSensitive; + private final boolean localityEnabled; + private final boolean executorCacheLocalityEnabled; + private final int scanHashCode; + + SparkBatch( + JavaSparkContext sparkContext, + Table table, + SparkReadConf readConf, + Types.StructType groupingKeyType, + List> taskGroups, + Schema expectedSchema, + int scanHashCode) { + this.sparkContext = sparkContext; + this.table = table; + this.branch = readConf.branch(); + this.readConf = readConf; + this.groupingKeyType = groupingKeyType; + this.taskGroups = taskGroups; + this.expectedSchema = expectedSchema; + this.caseSensitive = readConf.caseSensitive(); + this.localityEnabled = readConf.localityEnabled(); + this.executorCacheLocalityEnabled = readConf.executorCacheLocalityEnabled(); + this.scanHashCode = scanHashCode; + } + + @Override + public InputPartition[] planInputPartitions() { + // broadcast the table metadata as input partitions will be sent to executors + Broadcast
<Table>
tableBroadcast = + sparkContext.broadcast(SerializableTableWithSize.copyOf(table)); + String expectedSchemaString = SchemaParser.toJson(expectedSchema); + String[][] locations = computePreferredLocations(); + + InputPartition[] partitions = new InputPartition[taskGroups.size()]; + + for (int index = 0; index < taskGroups.size(); index++) { + partitions[index] = + new SparkInputPartition( + groupingKeyType, + taskGroups.get(index), + tableBroadcast, + branch, + expectedSchemaString, + caseSensitive, + locations != null ? locations[index] : SparkPlanningUtil.NO_LOCATION_PREFERENCE); + } + + return partitions; + } + + private String[][] computePreferredLocations() { + if (localityEnabled) { + return SparkPlanningUtil.fetchBlockLocations(table.io(), taskGroups); + + } else if (executorCacheLocalityEnabled) { + List executorLocations = SparkUtil.executorLocations(); + if (!executorLocations.isEmpty()) { + return SparkPlanningUtil.assignExecutors(taskGroups, executorLocations); + } + } + + return null; + } + + @Override + public PartitionReaderFactory createReaderFactory() { + if (useParquetBatchReads()) { + int batchSize = readConf.parquetBatchSize(); + return new SparkColumnarReaderFactory(batchSize); + + } else if (useOrcBatchReads()) { + int batchSize = readConf.orcBatchSize(); + return new SparkColumnarReaderFactory(batchSize); + + } else { + return new SparkRowReaderFactory(); + } + } + + // conditions for using Parquet batch reads: + // - Parquet vectorization is enabled + // - only primitives or metadata columns are projected + // - all tasks are of FileScanTask type and read only Parquet files + private boolean useParquetBatchReads() { + return readConf.parquetVectorizationEnabled() + && expectedSchema.columns().stream().allMatch(this::supportsParquetBatchReads) + && taskGroups.stream().allMatch(this::supportsParquetBatchReads); + } + + private boolean supportsParquetBatchReads(ScanTask task) { + if (task instanceof ScanTaskGroup) { + ScanTaskGroup taskGroup = (ScanTaskGroup) task; + return taskGroup.tasks().stream().allMatch(this::supportsParquetBatchReads); + + } else if (task.isFileScanTask() && !task.isDataTask()) { + FileScanTask fileScanTask = task.asFileScanTask(); + return fileScanTask.file().format() == FileFormat.PARQUET; + + } else { + return false; + } + } + + private boolean supportsParquetBatchReads(Types.NestedField field) { + return field.type().isPrimitiveType() || MetadataColumns.isMetadataColumn(field.fieldId()); + } + + // conditions for using ORC batch reads: + // - ORC vectorization is enabled + // - all tasks are of type FileScanTask and read only ORC files with no delete files + private boolean useOrcBatchReads() { + return readConf.orcVectorizationEnabled() + && taskGroups.stream().allMatch(this::supportsOrcBatchReads); + } + + private boolean supportsOrcBatchReads(ScanTask task) { + if (task instanceof ScanTaskGroup) { + ScanTaskGroup taskGroup = (ScanTaskGroup) task; + return taskGroup.tasks().stream().allMatch(this::supportsOrcBatchReads); + + } else if (task.isFileScanTask() && !task.isDataTask()) { + FileScanTask fileScanTask = task.asFileScanTask(); + return fileScanTask.file().format() == FileFormat.ORC && fileScanTask.deletes().isEmpty(); + + } else { + return false; + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + SparkBatch that = (SparkBatch) o; + return table.name().equals(that.table.name()) && scanHashCode == that.scanHashCode; + 
} + + @Override + public int hashCode() { + return Objects.hash(table.name(), scanHashCode); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java new file mode 100644 index 000000000000..18e483f23fc6 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Scan; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.Evaluator; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.ExpressionUtil; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.expressions.Projections; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkV2Filters; +import org.apache.iceberg.util.ContentFileUtil; +import org.apache.iceberg.util.DeleteFileSet; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class SparkBatchQueryScan extends SparkPartitioningAwareScan + implements SupportsRuntimeV2Filtering { + + private static final Logger LOG = LoggerFactory.getLogger(SparkBatchQueryScan.class); + + private final Long snapshotId; + private final Long startSnapshotId; + private final Long 
endSnapshotId; + private final Long asOfTimestamp; + private final String tag; + private final List runtimeFilterExpressions; + + SparkBatchQueryScan( + SparkSession spark, + Table table, + Scan> scan, + SparkReadConf readConf, + Schema expectedSchema, + List filters, + Supplier scanReportSupplier) { + super(spark, table, scan, readConf, expectedSchema, filters, scanReportSupplier); + + this.snapshotId = readConf.snapshotId(); + this.startSnapshotId = readConf.startSnapshotId(); + this.endSnapshotId = readConf.endSnapshotId(); + this.asOfTimestamp = readConf.asOfTimestamp(); + this.tag = readConf.tag(); + this.runtimeFilterExpressions = Lists.newArrayList(); + } + + Long snapshotId() { + return snapshotId; + } + + @Override + protected Class taskJavaClass() { + return PartitionScanTask.class; + } + + @Override + public NamedReference[] filterAttributes() { + Set partitionFieldSourceIds = Sets.newHashSet(); + + for (PartitionSpec spec : specs()) { + for (PartitionField field : spec.fields()) { + partitionFieldSourceIds.add(field.sourceId()); + } + } + + Map quotedNameById = SparkSchemaUtil.indexQuotedNameById(expectedSchema()); + + // the optimizer will look for an equality condition with filter attributes in a join + // as the scan has been already planned, filtering can only be done on projected attributes + // that's why only partition source fields that are part of the read schema can be reported + + return partitionFieldSourceIds.stream() + .filter(fieldId -> expectedSchema().findField(fieldId) != null) + .map(fieldId -> Spark3Util.toNamedReference(quotedNameById.get(fieldId))) + .toArray(NamedReference[]::new); + } + + @Override + public void filter(Predicate[] predicates) { + Expression runtimeFilterExpr = convertRuntimeFilters(predicates); + + if (runtimeFilterExpr != Expressions.alwaysTrue()) { + Map evaluatorsBySpecId = Maps.newHashMap(); + + for (PartitionSpec spec : specs()) { + Expression inclusiveExpr = + Projections.inclusive(spec, caseSensitive()).project(runtimeFilterExpr); + Evaluator inclusive = new Evaluator(spec.partitionType(), inclusiveExpr); + evaluatorsBySpecId.put(spec.specId(), inclusive); + } + + List filteredTasks = + tasks().stream() + .filter( + task -> { + Evaluator evaluator = evaluatorsBySpecId.get(task.spec().specId()); + return evaluator.eval(task.partition()); + }) + .collect(Collectors.toList()); + + LOG.info( + "{} of {} task(s) for table {} matched runtime filter {}", + filteredTasks.size(), + tasks().size(), + table().name(), + ExpressionUtil.toSanitizedString(runtimeFilterExpr)); + + // don't invalidate tasks if the runtime filter had no effect to avoid planning splits again + if (filteredTasks.size() < tasks().size()) { + resetTasks(filteredTasks); + } + + // save the evaluated filter for equals/hashCode + runtimeFilterExpressions.add(runtimeFilterExpr); + } + } + + protected Map rewritableDeletes() { + Map rewritableDeletes = Maps.newHashMap(); + + for (ScanTask task : tasks()) { + FileScanTask fileScanTask = task.asFileScanTask(); + for (DeleteFile deleteFile : fileScanTask.deletes()) { + if (ContentFileUtil.isFileScoped(deleteFile)) { + rewritableDeletes + .computeIfAbsent(fileScanTask.file().location(), ignored -> DeleteFileSet.create()) + .add(deleteFile); + } + } + } + + return rewritableDeletes; + } + + // at this moment, Spark can only pass IN filters for a single attribute + // if there are multiple filter attributes, Spark will pass two separate IN filters + private Expression convertRuntimeFilters(Predicate[] predicates) { + Expression 
runtimeFilterExpr = Expressions.alwaysTrue(); + + for (Predicate predicate : predicates) { + Expression expr = SparkV2Filters.convert(predicate); + if (expr != null) { + try { + Binder.bind(expectedSchema().asStruct(), expr, caseSensitive()); + runtimeFilterExpr = Expressions.and(runtimeFilterExpr, expr); + } catch (ValidationException e) { + LOG.warn("Failed to bind {} to expected schema, skipping runtime filter", expr, e); + } + } else { + LOG.warn("Unsupported runtime filter {}", predicate); + } + } + + return runtimeFilterExpr; + } + + @Override + public Statistics estimateStatistics() { + if (scan() == null) { + return estimateStatistics(null); + + } else if (snapshotId != null) { + Snapshot snapshot = table().snapshot(snapshotId); + return estimateStatistics(snapshot); + + } else if (asOfTimestamp != null) { + long snapshotIdAsOfTime = SnapshotUtil.snapshotIdAsOfTime(table(), asOfTimestamp); + Snapshot snapshot = table().snapshot(snapshotIdAsOfTime); + return estimateStatistics(snapshot); + + } else if (branch() != null) { + Snapshot snapshot = table().snapshot(branch()); + return estimateStatistics(snapshot); + + } else if (tag != null) { + Snapshot snapshot = table().snapshot(tag); + return estimateStatistics(snapshot); + + } else { + Snapshot snapshot = table().currentSnapshot(); + return estimateStatistics(snapshot); + } + } + + @Override + @SuppressWarnings("checkstyle:CyclomaticComplexity") + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + SparkBatchQueryScan that = (SparkBatchQueryScan) o; + return table().name().equals(that.table().name()) + && Objects.equals(branch(), that.branch()) + && readSchema().equals(that.readSchema()) // compare Spark schemas to ignore field ids + && filterExpressions().toString().equals(that.filterExpressions().toString()) + && runtimeFilterExpressions.toString().equals(that.runtimeFilterExpressions.toString()) + && Objects.equals(snapshotId, that.snapshotId) + && Objects.equals(startSnapshotId, that.startSnapshotId) + && Objects.equals(endSnapshotId, that.endSnapshotId) + && Objects.equals(asOfTimestamp, that.asOfTimestamp) + && Objects.equals(tag, that.tag); + } + + @Override + public int hashCode() { + return Objects.hash( + table().name(), + branch(), + readSchema(), + filterExpressions().toString(), + runtimeFilterExpressions.toString(), + snapshotId, + startSnapshotId, + endSnapshotId, + asOfTimestamp, + tag); + } + + @Override + public String toString() { + return String.format( + "IcebergScan(table=%s, branch=%s, type=%s, filters=%s, runtimeFilters=%s, caseSensitive=%s)", + table(), + branch(), + expectedSchema().asStruct(), + filterExpressions(), + runtimeFilterExpressions, + caseSensitive()); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkChangelogScan.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkChangelogScan.java new file mode 100644 index 000000000000..71b53d70262f --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkChangelogScan.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import org.apache.iceberg.ChangelogScanTask; +import org.apache.iceberg.IncrementalChangelogScan; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsReportStatistics; +import org.apache.spark.sql.types.StructType; + +class SparkChangelogScan implements Scan, SupportsReportStatistics { + + private static final Types.StructType EMPTY_GROUPING_KEY_TYPE = Types.StructType.of(); + + private final JavaSparkContext sparkContext; + private final Table table; + private final IncrementalChangelogScan scan; + private final SparkReadConf readConf; + private final Schema expectedSchema; + private final List filters; + private final Long startSnapshotId; + private final Long endSnapshotId; + + // lazy variables + private List> taskGroups = null; + private StructType expectedSparkType = null; + + SparkChangelogScan( + SparkSession spark, + Table table, + IncrementalChangelogScan scan, + SparkReadConf readConf, + Schema expectedSchema, + List filters, + boolean emptyScan) { + + SparkSchemaUtil.validateMetadataColumnReferences(table.schema(), expectedSchema); + + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.scan = scan; + this.readConf = readConf; + this.expectedSchema = expectedSchema; + this.filters = filters != null ? 
filters : Collections.emptyList(); + this.startSnapshotId = readConf.startSnapshotId(); + this.endSnapshotId = readConf.endSnapshotId(); + if (emptyScan) { + this.taskGroups = Collections.emptyList(); + } + } + + @Override + public Statistics estimateStatistics() { + long rowsCount = taskGroups().stream().mapToLong(ScanTaskGroup::estimatedRowsCount).sum(); + long sizeInBytes = SparkSchemaUtil.estimateSize(readSchema(), rowsCount); + return new Stats(sizeInBytes, rowsCount, Collections.emptyMap()); + } + + @Override + public StructType readSchema() { + if (expectedSparkType == null) { + this.expectedSparkType = SparkSchemaUtil.convert(expectedSchema); + } + + return expectedSparkType; + } + + @Override + public Batch toBatch() { + return new SparkBatch( + sparkContext, + table, + readConf, + EMPTY_GROUPING_KEY_TYPE, + taskGroups(), + expectedSchema, + hashCode()); + } + + private List> taskGroups() { + if (taskGroups == null) { + try (CloseableIterable> groups = scan.planTasks()) { + this.taskGroups = Lists.newArrayList(groups); + } catch (IOException e) { + throw new UncheckedIOException("Failed to close changelog scan: " + scan, e); + } + } + + return taskGroups; + } + + @Override + public String description() { + return String.format( + "%s [fromSnapshotId=%d, toSnapshotId=%d, filters=%s]", + table, startSnapshotId, endSnapshotId, Spark3Util.describe(filters)); + } + + @Override + public String toString() { + return String.format( + "IcebergChangelogScan(table=%s, type=%s, fromSnapshotId=%d, toSnapshotId=%d, filters=%s)", + table, + expectedSchema.asStruct(), + startSnapshotId, + endSnapshotId, + Spark3Util.describe(filters)); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + SparkChangelogScan that = (SparkChangelogScan) o; + return table.name().equals(that.table.name()) + && readSchema().equals(that.readSchema()) // compare Spark schemas to ignore field IDs + && filters.toString().equals(that.filters.toString()) + && Objects.equals(startSnapshotId, that.startSnapshotId) + && Objects.equals(endSnapshotId, that.endSnapshotId); + } + + @Override + public int hashCode() { + return Objects.hash( + table.name(), readSchema(), filters.toString(), startSnapshotId, endSnapshotId); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkChangelogTable.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkChangelogTable.java new file mode 100644 index 000000000000..61611a08c4d4 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkChangelogTable.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Set; +import org.apache.iceberg.ChangelogUtil; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.MetadataColumn; +import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +public class SparkChangelogTable implements Table, SupportsRead, SupportsMetadataColumns { + + public static final String TABLE_NAME = "changes"; + + private static final Set CAPABILITIES = + ImmutableSet.of(TableCapability.BATCH_READ); + + private final org.apache.iceberg.Table icebergTable; + private final boolean refreshEagerly; + + private SparkSession lazySpark = null; + private StructType lazyTableSparkType = null; + private Schema lazyChangelogSchema = null; + + public SparkChangelogTable(org.apache.iceberg.Table icebergTable, boolean refreshEagerly) { + this.icebergTable = icebergTable; + this.refreshEagerly = refreshEagerly; + } + + @Override + public String name() { + return icebergTable.name() + "." + TABLE_NAME; + } + + @Override + public StructType schema() { + if (lazyTableSparkType == null) { + this.lazyTableSparkType = SparkSchemaUtil.convert(changelogSchema()); + } + + return lazyTableSparkType; + } + + @Override + public Set capabilities() { + return CAPABILITIES; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + if (refreshEagerly) { + icebergTable.refresh(); + } + + return new SparkScanBuilder(spark(), icebergTable, changelogSchema(), options) { + @Override + public Scan build() { + return buildChangelogScan(); + } + }; + } + + private Schema changelogSchema() { + if (lazyChangelogSchema == null) { + this.lazyChangelogSchema = ChangelogUtil.changelogSchema(icebergTable.schema()); + } + + return lazyChangelogSchema; + } + + private SparkSession spark() { + if (lazySpark == null) { + this.lazySpark = SparkSession.active(); + } + + return lazySpark; + } + + @Override + public MetadataColumn[] metadataColumns() { + DataType sparkPartitionType = SparkSchemaUtil.convert(Partitioning.partitionType(icebergTable)); + return new MetadataColumn[] { + new SparkMetadataColumn(MetadataColumns.SPEC_ID.name(), DataTypes.IntegerType, false), + new SparkMetadataColumn(MetadataColumns.PARTITION_COLUMN_NAME, sparkPartitionType, true), + new SparkMetadataColumn(MetadataColumns.FILE_PATH.name(), DataTypes.StringType, false), + new SparkMetadataColumn(MetadataColumns.ROW_POSITION.name(), DataTypes.LongType, false), + new SparkMetadataColumn(MetadataColumns.IS_DELETED.name(), DataTypes.BooleanType, false) + }; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkCleanupUtil.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkCleanupUtil.java new file mode 100644 index 
000000000000..5f343128161d --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkCleanupUtil.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.exceptions.NotFoundException; +import org.apache.iceberg.io.BulkDeletionFailureException; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.TaskContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A utility for cleaning up written but not committed files. */ +class SparkCleanupUtil { + + private static final Logger LOG = LoggerFactory.getLogger(SparkCleanupUtil.class); + + private static final int DELETE_NUM_RETRIES = 3; + private static final int DELETE_MIN_RETRY_WAIT_MS = 100; // 100 ms + private static final int DELETE_MAX_RETRY_WAIT_MS = 30 * 1000; // 30 seconds + private static final int DELETE_TOTAL_RETRY_TIME_MS = 2 * 60 * 1000; // 2 minutes + + private SparkCleanupUtil() {} + + /** + * Attempts to delete as many files produced by a task as possible. + * + *
<p>
Note this method will log Spark task info and is supposed to be called only on executors. + * Use {@link #deleteFiles(String, FileIO, List)} to delete files on the driver. + * + * @param io a {@link FileIO} instance used for deleting files + * @param files a list of files to delete + */ + public static void deleteTaskFiles(FileIO io, List> files) { + deleteFiles(taskInfo(), io, files); + } + + // the format matches what Spark uses for internal logging + private static String taskInfo() { + TaskContext taskContext = TaskContext.get(); + if (taskContext == null) { + return "unknown task"; + } else { + return String.format( + "partition %d (task %d, attempt %d, stage %d.%d)", + taskContext.partitionId(), + taskContext.taskAttemptId(), + taskContext.attemptNumber(), + taskContext.stageId(), + taskContext.stageAttemptNumber()); + } + } + + /** + * Attempts to delete as many given files as possible. + * + * @param context a helpful description of the operation invoking this method + * @param io a {@link FileIO} instance used for deleting files + * @param files a list of files to delete + */ + public static void deleteFiles(String context, FileIO io, List> files) { + List paths = Lists.transform(files, ContentFile::location); + deletePaths(context, io, paths); + } + + private static void deletePaths(String context, FileIO io, List paths) { + if (io instanceof SupportsBulkOperations) { + SupportsBulkOperations bulkIO = (SupportsBulkOperations) io; + bulkDelete(context, bulkIO, paths); + } else { + delete(context, io, paths); + } + } + + private static void bulkDelete(String context, SupportsBulkOperations io, List paths) { + try { + io.deleteFiles(paths); + LOG.info("Deleted {} file(s) using bulk deletes ({})", paths.size(), context); + + } catch (BulkDeletionFailureException e) { + int deletedFilesCount = paths.size() - e.numberFailedObjects(); + LOG.warn( + "Deleted only {} of {} file(s) using bulk deletes ({})", + deletedFilesCount, + paths.size(), + context); + } + } + + private static void delete(String context, FileIO io, List paths) { + AtomicInteger deletedFilesCount = new AtomicInteger(0); + + Tasks.foreach(paths) + .executeWith(ThreadPools.getWorkerPool()) + .stopRetryOn(NotFoundException.class) + .suppressFailureWhenFinished() + .onFailure((path, exc) -> LOG.warn("Failed to delete {} ({})", path, context, exc)) + .retry(DELETE_NUM_RETRIES) + .exponentialBackoff( + DELETE_MIN_RETRY_WAIT_MS, + DELETE_MAX_RETRY_WAIT_MS, + DELETE_TOTAL_RETRY_TIME_MS, + 2 /* exponential */) + .run( + path -> { + io.deleteFile(path); + deletedFilesCount.incrementAndGet(); + }); + + if (deletedFilesCount.get() < paths.size()) { + LOG.warn("Deleted only {} of {} file(s) ({})", deletedFilesCount, paths.size(), context); + } else { + LOG.info("Deleted {} file(s) ({})", paths.size(), context); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkColumnStatistics.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkColumnStatistics.java new file mode 100644 index 000000000000..faaff3631d7c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkColumnStatistics.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Optional; +import java.util.OptionalLong; +import org.apache.spark.sql.connector.read.colstats.ColumnStatistics; +import org.apache.spark.sql.connector.read.colstats.Histogram; + +class SparkColumnStatistics implements ColumnStatistics { + + private final OptionalLong distinctCount; + private final Optional min; + private final Optional max; + private final OptionalLong nullCount; + private final OptionalLong avgLen; + private final OptionalLong maxLen; + private final Optional histogram; + + SparkColumnStatistics( + Long distinctCount, + Object min, + Object max, + Long nullCount, + Long avgLen, + Long maxLen, + Histogram histogram) { + this.distinctCount = + (distinctCount == null) ? OptionalLong.empty() : OptionalLong.of(distinctCount); + this.min = Optional.ofNullable(min); + this.max = Optional.ofNullable(max); + this.nullCount = (nullCount == null) ? OptionalLong.empty() : OptionalLong.of(nullCount); + this.avgLen = (avgLen == null) ? OptionalLong.empty() : OptionalLong.of(avgLen); + this.maxLen = (maxLen == null) ? OptionalLong.empty() : OptionalLong.of(maxLen); + this.histogram = Optional.ofNullable(histogram); + } + + @Override + public OptionalLong distinctCount() { + return distinctCount; + } + + @Override + public Optional min() { + return min; + } + + @Override + public Optional max() { + return max; + } + + @Override + public OptionalLong nullCount() { + return nullCount; + } + + @Override + public OptionalLong avgLen() { + return avgLen; + } + + @Override + public OptionalLong maxLen() { + return maxLen; + } + + @Override + public Optional histogram() { + return histogram; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkColumnarReaderFactory.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkColumnarReaderFactory.java new file mode 100644 index 000000000000..655e20a50e11 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkColumnarReaderFactory.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +class SparkColumnarReaderFactory implements PartitionReaderFactory { + private final int batchSize; + + SparkColumnarReaderFactory(int batchSize) { + Preconditions.checkArgument(batchSize > 1, "Batch size must be > 1"); + this.batchSize = batchSize; + } + + @Override + public PartitionReader createReader(InputPartition inputPartition) { + throw new UnsupportedOperationException("Row-based reads are not supported"); + } + + @Override + public PartitionReader createColumnarReader(InputPartition inputPartition) { + Preconditions.checkArgument( + inputPartition instanceof SparkInputPartition, + "Unknown input partition type: %s", + inputPartition.getClass().getName()); + + SparkInputPartition partition = (SparkInputPartition) inputPartition; + + if (partition.allTasksOfType(FileScanTask.class)) { + return new BatchDataReader(partition, batchSize); + + } else { + throw new UnsupportedOperationException( + "Unsupported task group for columnar reads: " + partition.taskGroup()); + } + } + + @Override + public boolean supportColumnarReads(InputPartition inputPartition) { + return true; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java new file mode 100644 index 000000000000..4fca05345a2e --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; + +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperation; +import org.apache.spark.sql.connector.write.RowLevelOperationInfo; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +class SparkCopyOnWriteOperation implements RowLevelOperation { + + private final SparkSession spark; + private final Table table; + private final String branch; + private final Command command; + private final IsolationLevel isolationLevel; + + // lazy vars + private ScanBuilder lazyScanBuilder; + private Scan configuredScan; + private WriteBuilder lazyWriteBuilder; + + SparkCopyOnWriteOperation( + SparkSession spark, + Table table, + String branch, + RowLevelOperationInfo info, + IsolationLevel isolationLevel) { + this.spark = spark; + this.table = table; + this.branch = branch; + this.command = info.command(); + this.isolationLevel = isolationLevel; + } + + @Override + public Command command() { + return command; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + if (lazyScanBuilder == null) { + lazyScanBuilder = + new SparkScanBuilder(spark, table, branch, options) { + @Override + public Scan build() { + Scan scan = super.buildCopyOnWriteScan(); + SparkCopyOnWriteOperation.this.configuredScan = scan; + return scan; + } + }; + } + + return lazyScanBuilder; + } + + @Override + public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { + if (lazyWriteBuilder == null) { + SparkWriteBuilder writeBuilder = new SparkWriteBuilder(spark, table, branch, info); + lazyWriteBuilder = writeBuilder.overwriteFiles(configuredScan, command, isolationLevel); + } + + return lazyWriteBuilder; + } + + @Override + public NamedReference[] requiredMetadataAttributes() { + NamedReference file = Expressions.column(MetadataColumns.FILE_PATH.name()); + NamedReference pos = Expressions.column(MetadataColumns.ROW_POSITION.name()); + + if (command == DELETE || command == UPDATE) { + return new NamedReference[] {file, pos}; + } else { + return new NamedReference[] {file}; + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java new file mode 100644 index 000000000000..7a6025b0731a --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import org.apache.iceberg.BatchScan; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.In; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class SparkCopyOnWriteScan extends SparkPartitioningAwareScan + implements SupportsRuntimeFiltering { + + private static final Logger LOG = LoggerFactory.getLogger(SparkCopyOnWriteScan.class); + + private final Snapshot snapshot; + private Set filteredLocations = null; + + SparkCopyOnWriteScan( + SparkSession spark, + Table table, + SparkReadConf readConf, + Schema expectedSchema, + List filters, + Supplier scanReportSupplier) { + this(spark, table, null, null, readConf, expectedSchema, filters, scanReportSupplier); + } + + SparkCopyOnWriteScan( + SparkSession spark, + Table table, + BatchScan scan, + Snapshot snapshot, + SparkReadConf readConf, + Schema expectedSchema, + List filters, + Supplier scanReportSupplier) { + super(spark, table, scan, readConf, expectedSchema, filters, scanReportSupplier); + + this.snapshot = snapshot; + + if (scan == null) { + this.filteredLocations = Collections.emptySet(); + } + } + + Long snapshotId() { + return snapshot != null ? snapshot.snapshotId() : null; + } + + @Override + protected Class taskJavaClass() { + return FileScanTask.class; + } + + @Override + public Statistics estimateStatistics() { + return estimateStatistics(snapshot); + } + + public NamedReference[] filterAttributes() { + NamedReference file = Expressions.column(MetadataColumns.FILE_PATH.name()); + return new NamedReference[] {file}; + } + + @Override + public void filter(Filter[] filters) { + Preconditions.checkState( + Objects.equals(snapshotId(), currentSnapshotId()), + "Runtime file filtering is not possible: the table has been concurrently modified. " + + "Row-level operation scan snapshot ID: %s, current table snapshot ID: %s. " + + "If an external process modifies the table, enable table caching in the catalog. 
" + + "If multiple threads modify the table, use independent Spark sessions in each thread.", + snapshotId(), + currentSnapshotId()); + + for (Filter filter : filters) { + // Spark can only pass In filters at the moment + if (filter instanceof In + && ((In) filter).attribute().equalsIgnoreCase(MetadataColumns.FILE_PATH.name())) { + In in = (In) filter; + + Set fileLocations = Sets.newHashSet(); + for (Object value : in.values()) { + fileLocations.add((String) value); + } + + // Spark may call this multiple times for UPDATEs with subqueries + // as such cases are rewritten using UNION and the same scan on both sides + // so filter files only if it is beneficial + if (filteredLocations == null || fileLocations.size() < filteredLocations.size()) { + this.filteredLocations = fileLocations; + List filteredTasks = + tasks().stream() + .filter(file -> fileLocations.contains(file.file().location())) + .collect(Collectors.toList()); + + LOG.info( + "{} of {} task(s) for table {} matched runtime file filter with {} location(s)", + filteredTasks.size(), + tasks().size(), + table().name(), + fileLocations.size()); + + resetTasks(filteredTasks); + } + } else { + LOG.warn("Unsupported runtime filter {}", filter); + } + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + SparkCopyOnWriteScan that = (SparkCopyOnWriteScan) o; + return table().name().equals(that.table().name()) + && readSchema().equals(that.readSchema()) // compare Spark schemas to ignore field ids + && filterExpressions().toString().equals(that.filterExpressions().toString()) + && Objects.equals(snapshotId(), that.snapshotId()) + && Objects.equals(filteredLocations, that.filteredLocations); + } + + @Override + public int hashCode() { + return Objects.hash( + table().name(), + readSchema(), + filterExpressions().toString(), + snapshotId(), + filteredLocations); + } + + @Override + public String toString() { + return String.format( + "IcebergCopyOnWriteScan(table=%s, type=%s, filters=%s, caseSensitive=%s)", + table(), expectedSchema().asStruct(), filterExpressions(), caseSensitive()); + } + + private Long currentSnapshotId() { + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table(), branch()); + return currentSnapshot != null ? currentSnapshot.snapshotId() : null; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java new file mode 100644 index 000000000000..50a1259c8626 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.MetadataColumns.DELETE_FILE_ROW_FIELD_NAME; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT_DEFAULT; +import static org.apache.iceberg.TableProperties.DELETE_DEFAULT_FILE_FORMAT; + +import java.util.Map; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.data.BaseFileWriterFactory; +import org.apache.iceberg.io.DeleteSchemaUtil; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.SparkAvroWriter; +import org.apache.iceberg.spark.data.SparkOrcWriter; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +class SparkFileWriterFactory extends BaseFileWriterFactory { + private StructType dataSparkType; + private StructType equalityDeleteSparkType; + private StructType positionDeleteSparkType; + private final Map writeProperties; + + SparkFileWriterFactory( + Table table, + FileFormat dataFileFormat, + Schema dataSchema, + StructType dataSparkType, + SortOrder dataSortOrder, + FileFormat deleteFileFormat, + int[] equalityFieldIds, + Schema equalityDeleteRowSchema, + StructType equalityDeleteSparkType, + SortOrder equalityDeleteSortOrder, + Schema positionDeleteRowSchema, + StructType positionDeleteSparkType, + Map writeProperties) { + + super( + table, + dataFileFormat, + dataSchema, + dataSortOrder, + deleteFileFormat, + equalityFieldIds, + equalityDeleteRowSchema, + equalityDeleteSortOrder, + positionDeleteRowSchema); + + this.dataSparkType = dataSparkType; + this.equalityDeleteSparkType = equalityDeleteSparkType; + this.positionDeleteSparkType = positionDeleteSparkType; + this.writeProperties = writeProperties != null ? 
writeProperties : ImmutableMap.of(); + } + + static Builder builderFor(Table table) { + return new Builder(table); + } + + @Override + protected void configureDataWrite(Avro.DataWriteBuilder builder) { + builder.createWriterFunc(ignored -> new SparkAvroWriter(dataSparkType())); + builder.setAll(writeProperties); + } + + @Override + protected void configureEqualityDelete(Avro.DeleteWriteBuilder builder) { + builder.createWriterFunc(ignored -> new SparkAvroWriter(equalityDeleteSparkType())); + builder.setAll(writeProperties); + } + + @Override + protected void configurePositionDelete(Avro.DeleteWriteBuilder builder) { + boolean withRow = + positionDeleteSparkType().getFieldIndex(DELETE_FILE_ROW_FIELD_NAME).isDefined(); + if (withRow) { + // SparkAvroWriter accepts just the Spark type of the row ignoring the path and pos + StructField rowField = positionDeleteSparkType().apply(DELETE_FILE_ROW_FIELD_NAME); + StructType positionDeleteRowSparkType = (StructType) rowField.dataType(); + builder.createWriterFunc(ignored -> new SparkAvroWriter(positionDeleteRowSparkType)); + } + + builder.setAll(writeProperties); + } + + @Override + protected void configureDataWrite(Parquet.DataWriteBuilder builder) { + builder.createWriterFunc(msgType -> SparkParquetWriters.buildWriter(dataSparkType(), msgType)); + builder.setAll(writeProperties); + } + + @Override + protected void configureEqualityDelete(Parquet.DeleteWriteBuilder builder) { + builder.createWriterFunc( + msgType -> SparkParquetWriters.buildWriter(equalityDeleteSparkType(), msgType)); + builder.setAll(writeProperties); + } + + @Override + protected void configurePositionDelete(Parquet.DeleteWriteBuilder builder) { + builder.createWriterFunc( + msgType -> SparkParquetWriters.buildWriter(positionDeleteSparkType(), msgType)); + builder.transformPaths(path -> UTF8String.fromString(path.toString())); + builder.setAll(writeProperties); + } + + @Override + protected void configureDataWrite(ORC.DataWriteBuilder builder) { + builder.createWriterFunc(SparkOrcWriter::new); + builder.setAll(writeProperties); + } + + @Override + protected void configureEqualityDelete(ORC.DeleteWriteBuilder builder) { + builder.createWriterFunc(SparkOrcWriter::new); + builder.setAll(writeProperties); + } + + @Override + protected void configurePositionDelete(ORC.DeleteWriteBuilder builder) { + builder.createWriterFunc(SparkOrcWriter::new); + builder.transformPaths(path -> UTF8String.fromString(path.toString())); + builder.setAll(writeProperties); + } + + private StructType dataSparkType() { + if (dataSparkType == null) { + Preconditions.checkNotNull(dataSchema(), "Data schema must not be null"); + this.dataSparkType = SparkSchemaUtil.convert(dataSchema()); + } + + return dataSparkType; + } + + private StructType equalityDeleteSparkType() { + if (equalityDeleteSparkType == null) { + Preconditions.checkNotNull( + equalityDeleteRowSchema(), "Equality delete schema must not be null"); + this.equalityDeleteSparkType = SparkSchemaUtil.convert(equalityDeleteRowSchema()); + } + + return equalityDeleteSparkType; + } + + private StructType positionDeleteSparkType() { + if (positionDeleteSparkType == null) { + // wrap the optional row schema into the position delete schema containing path and position + Schema positionDeleteSchema = DeleteSchemaUtil.posDeleteSchema(positionDeleteRowSchema()); + this.positionDeleteSparkType = SparkSchemaUtil.convert(positionDeleteSchema); + } + + return positionDeleteSparkType; + } + + static class Builder { + private final Table table; + private FileFormat 
dataFileFormat; + private Schema dataSchema; + private StructType dataSparkType; + private SortOrder dataSortOrder; + private FileFormat deleteFileFormat; + private int[] equalityFieldIds; + private Schema equalityDeleteRowSchema; + private StructType equalityDeleteSparkType; + private SortOrder equalityDeleteSortOrder; + private Schema positionDeleteRowSchema; + private StructType positionDeleteSparkType; + private Map writeProperties; + + Builder(Table table) { + this.table = table; + + Map properties = table.properties(); + + String dataFileFormatName = + properties.getOrDefault(DEFAULT_FILE_FORMAT, DEFAULT_FILE_FORMAT_DEFAULT); + this.dataFileFormat = FileFormat.fromString(dataFileFormatName); + + String deleteFileFormatName = + properties.getOrDefault(DELETE_DEFAULT_FILE_FORMAT, dataFileFormatName); + this.deleteFileFormat = FileFormat.fromString(deleteFileFormatName); + } + + Builder dataFileFormat(FileFormat newDataFileFormat) { + this.dataFileFormat = newDataFileFormat; + return this; + } + + Builder dataSchema(Schema newDataSchema) { + this.dataSchema = newDataSchema; + return this; + } + + Builder dataSparkType(StructType newDataSparkType) { + this.dataSparkType = newDataSparkType; + return this; + } + + Builder dataSortOrder(SortOrder newDataSortOrder) { + this.dataSortOrder = newDataSortOrder; + return this; + } + + Builder deleteFileFormat(FileFormat newDeleteFileFormat) { + this.deleteFileFormat = newDeleteFileFormat; + return this; + } + + Builder equalityFieldIds(int[] newEqualityFieldIds) { + this.equalityFieldIds = newEqualityFieldIds; + return this; + } + + Builder equalityDeleteRowSchema(Schema newEqualityDeleteRowSchema) { + this.equalityDeleteRowSchema = newEqualityDeleteRowSchema; + return this; + } + + Builder equalityDeleteSparkType(StructType newEqualityDeleteSparkType) { + this.equalityDeleteSparkType = newEqualityDeleteSparkType; + return this; + } + + Builder equalityDeleteSortOrder(SortOrder newEqualityDeleteSortOrder) { + this.equalityDeleteSortOrder = newEqualityDeleteSortOrder; + return this; + } + + Builder positionDeleteRowSchema(Schema newPositionDeleteRowSchema) { + this.positionDeleteRowSchema = newPositionDeleteRowSchema; + return this; + } + + Builder positionDeleteSparkType(StructType newPositionDeleteSparkType) { + this.positionDeleteSparkType = newPositionDeleteSparkType; + return this; + } + + Builder writeProperties(Map properties) { + this.writeProperties = properties; + return this; + } + + SparkFileWriterFactory build() { + boolean noEqualityDeleteConf = equalityFieldIds == null && equalityDeleteRowSchema == null; + boolean fullEqualityDeleteConf = equalityFieldIds != null && equalityDeleteRowSchema != null; + Preconditions.checkArgument( + noEqualityDeleteConf || fullEqualityDeleteConf, + "Equality field IDs and equality delete row schema must be set together"); + + return new SparkFileWriterFactory( + table, + dataFileFormat, + dataSchema, + dataSparkType, + dataSortOrder, + deleteFileFormat, + equalityFieldIds, + equalityDeleteRowSchema, + equalityDeleteSparkType, + equalityDeleteSortOrder, + positionDeleteRowSchema, + positionDeleteSparkType, + writeProperties); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java new file mode 100644 index 000000000000..7826322be7de --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java @@ -0,0 +1,99 @@ 
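The factory above is assembled through its Builder; a minimal usage sketch for a plain Parquet data write (the Spark type conversion and empty property map are illustrative choices, not values taken from this diff):

    SparkFileWriterFactory writerFactory =
        SparkFileWriterFactory.builderFor(table)
            .dataFileFormat(FileFormat.PARQUET)
            .dataSchema(table.schema())
            .dataSparkType(SparkSchemaUtil.convert(table.schema()))  // optional; derived lazily when omitted
            .writeProperties(ImmutableMap.of())                      // e.g. compression overrides
            .build();

Because neither equality field IDs nor an equality delete row schema are set, the precondition in build() that requires them to be configured together is satisfied.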
+/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.io.Serializable; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SchemaParser; +import org.apache.iceberg.Table; +import org.apache.iceberg.types.Types; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.HasPartitionKey; +import org.apache.spark.sql.connector.read.InputPartition; + +class SparkInputPartition implements InputPartition, HasPartitionKey, Serializable { + private final Types.StructType groupingKeyType; + private final ScanTaskGroup taskGroup; + private final Broadcast
<Table>
tableBroadcast; + private final String branch; + private final String expectedSchemaString; + private final boolean caseSensitive; + private final transient String[] preferredLocations; + + private transient Schema expectedSchema = null; + + SparkInputPartition( + Types.StructType groupingKeyType, + ScanTaskGroup taskGroup, + Broadcast
tableBroadcast, + String branch, + String expectedSchemaString, + boolean caseSensitive, + String[] preferredLocations) { + this.groupingKeyType = groupingKeyType; + this.taskGroup = taskGroup; + this.tableBroadcast = tableBroadcast; + this.branch = branch; + this.expectedSchemaString = expectedSchemaString; + this.caseSensitive = caseSensitive; + this.preferredLocations = preferredLocations; + } + + @Override + public String[] preferredLocations() { + return preferredLocations; + } + + @Override + public InternalRow partitionKey() { + return new StructInternalRow(groupingKeyType).setStruct(taskGroup.groupingKey()); + } + + @SuppressWarnings("unchecked") + public ScanTaskGroup taskGroup() { + return (ScanTaskGroup) taskGroup; + } + + public boolean allTasksOfType(Class javaClass) { + return taskGroup.tasks().stream().allMatch(javaClass::isInstance); + } + + public Table table() { + return tableBroadcast.value(); + } + + public String branch() { + return branch; + } + + public boolean isCaseSensitive() { + return caseSensitive; + } + + public Schema expectedSchema() { + if (expectedSchema == null) { + this.expectedSchema = SchemaParser.fromJson(expectedSchemaString); + } + + return expectedSchema; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkLocalScan.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkLocalScan.java new file mode 100644 index 000000000000..c2f9707775dd --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkLocalScan.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
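SparkInputPartition (above) is the unit the driver ships to executors; a rough sketch of how a scan builds one per planned task group (the variable names are illustrative, and the broadcast is created from SerializableTableWithSize as elsewhere in this module):

    InputPartition partition =
        new SparkInputPartition(
            Types.StructType.of(),                      // empty grouping key when data grouping is off
            taskGroup,                                  // a planned ScanTaskGroup of FileScanTasks
            tableBroadcast,                             // Broadcast of SerializableTableWithSize.copyOf(table)
            branch,
            SchemaParser.toJson(expectedSchema),
            caseSensitive,
            SparkPlanningUtil.NO_LOCATION_PREFERENCE);  // or block locations when locality is enabled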
+ */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.LocalScan; +import org.apache.spark.sql.types.StructType; + +class SparkLocalScan implements LocalScan { + + private final Table table; + private final StructType readSchema; + private final InternalRow[] rows; + private final List filterExpressions; + + SparkLocalScan( + Table table, StructType readSchema, InternalRow[] rows, List filterExpressions) { + this.table = table; + this.readSchema = readSchema; + this.rows = rows; + this.filterExpressions = filterExpressions; + } + + @Override + public InternalRow[] rows() { + return rows; + } + + @Override + public StructType readSchema() { + return readSchema; + } + + @Override + public String description() { + return String.format("%s [filters=%s]", table, Spark3Util.describe(filterExpressions)); + } + + @Override + public String toString() { + return String.format( + "IcebergLocalScan(table=%s, type=%s, filters=%s)", + table, SparkSchemaUtil.convert(readSchema).asStruct(), filterExpressions); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkMetadataColumn.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkMetadataColumn.java new file mode 100644 index 000000000000..94f87c28741d --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkMetadataColumn.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.spark.sql.connector.catalog.MetadataColumn; +import org.apache.spark.sql.types.DataType; + +public class SparkMetadataColumn implements MetadataColumn { + + private final String name; + private final DataType dataType; + private final boolean isNullable; + + public SparkMetadataColumn(String name, DataType dataType, boolean isNullable) { + this.name = name; + this.dataType = dataType; + this.isNullable = isNullable; + } + + @Override + public String name() { + return name; + } + + @Override + public DataType dataType() { + return dataType; + } + + @Override + public boolean isNullable() { + return isNullable; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java new file mode 100644 index 000000000000..49180e07c465 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java @@ -0,0 +1,521 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
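SparkMetadataColumn (above) is the bridge that exposes Iceberg metadata columns to Spark's catalog API; for example, the _file column could be surfaced roughly like this (the wiring into the table implementation is not part of this excerpt):

    MetadataColumn fileColumn =
        new SparkMetadataColumn(
            MetadataColumns.FILE_PATH.name(),  // "_file"
            DataTypes.StringType,
            false);                            // the file path is always present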
+ */ +package org.apache.iceberg.spark.source; + +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Locale; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DataOperations; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.MicroBatches; +import org.apache.iceberg.MicroBatches.MicroBatch; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SchemaParser; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.streaming.MicroBatchStream; +import org.apache.spark.sql.connector.read.streaming.Offset; +import org.apache.spark.sql.connector.read.streaming.ReadLimit; +import org.apache.spark.sql.connector.read.streaming.SupportsAdmissionControl; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SparkMicroBatchStream implements MicroBatchStream, SupportsAdmissionControl { + private static final Joiner SLASH = Joiner.on("/"); + private static final Logger LOG = LoggerFactory.getLogger(SparkMicroBatchStream.class); + private static final Types.StructType EMPTY_GROUPING_KEY_TYPE = Types.StructType.of(); + + private final Table table; + private final String branch; + private final boolean caseSensitive; + private final String expectedSchema; + private final Broadcast
tableBroadcast; + private final long splitSize; + private final int splitLookback; + private final long splitOpenFileCost; + private final boolean localityPreferred; + private final StreamingOffset initialOffset; + private final boolean skipDelete; + private final boolean skipOverwrite; + private final long fromTimestamp; + private final int maxFilesPerMicroBatch; + private final int maxRecordsPerMicroBatch; + + SparkMicroBatchStream( + JavaSparkContext sparkContext, + Table table, + SparkReadConf readConf, + Schema expectedSchema, + String checkpointLocation) { + this.table = table; + this.branch = readConf.branch(); + this.caseSensitive = readConf.caseSensitive(); + this.expectedSchema = SchemaParser.toJson(expectedSchema); + this.localityPreferred = readConf.localityEnabled(); + this.tableBroadcast = sparkContext.broadcast(SerializableTableWithSize.copyOf(table)); + this.splitSize = readConf.splitSize(); + this.splitLookback = readConf.splitLookback(); + this.splitOpenFileCost = readConf.splitOpenFileCost(); + this.fromTimestamp = readConf.streamFromTimestamp(); + this.maxFilesPerMicroBatch = readConf.maxFilesPerMicroBatch(); + this.maxRecordsPerMicroBatch = readConf.maxRecordsPerMicroBatch(); + + InitialOffsetStore initialOffsetStore = + new InitialOffsetStore(table, checkpointLocation, fromTimestamp); + this.initialOffset = initialOffsetStore.initialOffset(); + + this.skipDelete = readConf.streamingSkipDeleteSnapshots(); + this.skipOverwrite = readConf.streamingSkipOverwriteSnapshots(); + } + + @Override + public Offset latestOffset() { + table.refresh(); + if (table.currentSnapshot() == null) { + return StreamingOffset.START_OFFSET; + } + + if (table.currentSnapshot().timestampMillis() < fromTimestamp) { + return StreamingOffset.START_OFFSET; + } + + Snapshot latestSnapshot = table.currentSnapshot(); + + return new StreamingOffset(latestSnapshot.snapshotId(), addedFilesCount(latestSnapshot), false); + } + + @Override + public InputPartition[] planInputPartitions(Offset start, Offset end) { + Preconditions.checkArgument( + end instanceof StreamingOffset, "Invalid end offset: %s is not a StreamingOffset", end); + Preconditions.checkArgument( + start instanceof StreamingOffset, + "Invalid start offset: %s is not a StreamingOffset", + start); + + if (end.equals(StreamingOffset.START_OFFSET)) { + return new InputPartition[0]; + } + + StreamingOffset endOffset = (StreamingOffset) end; + StreamingOffset startOffset = (StreamingOffset) start; + + List fileScanTasks = planFiles(startOffset, endOffset); + + CloseableIterable splitTasks = + TableScanUtil.splitFiles(CloseableIterable.withNoopClose(fileScanTasks), splitSize); + List combinedScanTasks = + Lists.newArrayList( + TableScanUtil.planTasks(splitTasks, splitSize, splitLookback, splitOpenFileCost)); + String[][] locations = computePreferredLocations(combinedScanTasks); + + InputPartition[] partitions = new InputPartition[combinedScanTasks.size()]; + + for (int index = 0; index < combinedScanTasks.size(); index++) { + partitions[index] = + new SparkInputPartition( + EMPTY_GROUPING_KEY_TYPE, + combinedScanTasks.get(index), + tableBroadcast, + branch, + expectedSchema, + caseSensitive, + locations != null ? locations[index] : SparkPlanningUtil.NO_LOCATION_PREFERENCE); + } + + return partitions; + } + + private String[][] computePreferredLocations(List taskGroups) { + return localityPreferred ? 
SparkPlanningUtil.fetchBlockLocations(table.io(), taskGroups) : null; + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return new SparkRowReaderFactory(); + } + + @Override + public Offset initialOffset() { + return initialOffset; + } + + @Override + public Offset deserializeOffset(String json) { + return StreamingOffset.fromJson(json); + } + + @Override + public void commit(Offset end) {} + + @Override + public void stop() {} + + private List planFiles(StreamingOffset startOffset, StreamingOffset endOffset) { + List fileScanTasks = Lists.newArrayList(); + StreamingOffset batchStartOffset = + StreamingOffset.START_OFFSET.equals(startOffset) + ? determineStartingOffset(table, fromTimestamp) + : startOffset; + + StreamingOffset currentOffset = null; + + // [(startOffset : startFileIndex), (endOffset : endFileIndex) ) + do { + long endFileIndex; + if (currentOffset == null) { + currentOffset = batchStartOffset; + } else { + Snapshot snapshotAfter = SnapshotUtil.snapshotAfter(table, currentOffset.snapshotId()); + // it may happen that we need to read this snapshot partially in case it's equal to + // endOffset. + if (currentOffset.snapshotId() != endOffset.snapshotId()) { + currentOffset = new StreamingOffset(snapshotAfter.snapshotId(), 0L, false); + } else { + currentOffset = endOffset; + } + } + + Snapshot snapshot = table.snapshot(currentOffset.snapshotId()); + + validateCurrentSnapshotExists(snapshot, currentOffset); + + if (!shouldProcess(snapshot)) { + LOG.debug("Skipping snapshot: {} of table {}", currentOffset.snapshotId(), table.name()); + continue; + } + + Snapshot currentSnapshot = table.snapshot(currentOffset.snapshotId()); + if (currentOffset.snapshotId() == endOffset.snapshotId()) { + endFileIndex = endOffset.position(); + } else { + endFileIndex = addedFilesCount(currentSnapshot); + } + + MicroBatch latestMicroBatch = + MicroBatches.from(currentSnapshot, table.io()) + .caseSensitive(caseSensitive) + .specsById(table.specs()) + .generate( + currentOffset.position(), + endFileIndex, + Long.MAX_VALUE, + currentOffset.shouldScanAllFiles()); + + fileScanTasks.addAll(latestMicroBatch.tasks()); + } while (currentOffset.snapshotId() != endOffset.snapshotId()); + + return fileScanTasks; + } + + private boolean shouldProcess(Snapshot snapshot) { + String op = snapshot.operation(); + switch (op) { + case DataOperations.APPEND: + return true; + case DataOperations.REPLACE: + return false; + case DataOperations.DELETE: + Preconditions.checkState( + skipDelete, + "Cannot process delete snapshot: %s, to ignore deletes, set %s=true", + snapshot.snapshotId(), + SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS); + return false; + case DataOperations.OVERWRITE: + Preconditions.checkState( + skipOverwrite, + "Cannot process overwrite snapshot: %s, to ignore overwrites, set %s=true", + snapshot.snapshotId(), + SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS); + return false; + default: + throw new IllegalStateException( + String.format( + "Cannot process unknown snapshot operation: %s (snapshot id %s)", + op.toLowerCase(Locale.ROOT), snapshot.snapshotId())); + } + } + + private static StreamingOffset determineStartingOffset(Table table, Long fromTimestamp) { + if (table.currentSnapshot() == null) { + return StreamingOffset.START_OFFSET; + } + + if (fromTimestamp == null) { + // match existing behavior and start from the oldest snapshot + return new StreamingOffset(SnapshotUtil.oldestAncestor(table).snapshotId(), 0, false); + } + + if 
(table.currentSnapshot().timestampMillis() < fromTimestamp) { + return StreamingOffset.START_OFFSET; + } + + try { + Snapshot snapshot = SnapshotUtil.oldestAncestorAfter(table, fromTimestamp); + if (snapshot != null) { + return new StreamingOffset(snapshot.snapshotId(), 0, false); + } else { + return StreamingOffset.START_OFFSET; + } + } catch (IllegalStateException e) { + // could not determine the first snapshot after the timestamp. use the oldest ancestor instead + return new StreamingOffset(SnapshotUtil.oldestAncestor(table).snapshotId(), 0, false); + } + } + + @Override + @SuppressWarnings("checkstyle:CyclomaticComplexity") + public Offset latestOffset(Offset startOffset, ReadLimit limit) { + // calculate end offset get snapshotId from the startOffset + Preconditions.checkArgument( + startOffset instanceof StreamingOffset, + "Invalid start offset: %s is not a StreamingOffset", + startOffset); + + table.refresh(); + if (table.currentSnapshot() == null) { + return StreamingOffset.START_OFFSET; + } + + if (table.currentSnapshot().timestampMillis() < fromTimestamp) { + return StreamingOffset.START_OFFSET; + } + + // end offset can expand to multiple snapshots + StreamingOffset startingOffset = (StreamingOffset) startOffset; + + if (startOffset.equals(StreamingOffset.START_OFFSET)) { + startingOffset = determineStartingOffset(table, fromTimestamp); + } + + Snapshot curSnapshot = table.snapshot(startingOffset.snapshotId()); + validateCurrentSnapshotExists(curSnapshot, startingOffset); + + int startPosOfSnapOffset = (int) startingOffset.position(); + + boolean scanAllFiles = startingOffset.shouldScanAllFiles(); + + boolean shouldContinueReading = true; + int curFilesAdded = 0; + int curRecordCount = 0; + int curPos = 0; + + // Note : we produce nextOffset with pos as non-inclusive + while (shouldContinueReading) { + // generate manifest index for the curSnapshot + List> indexedManifests = + MicroBatches.skippedManifestIndexesFromSnapshot( + table.io(), curSnapshot, startPosOfSnapOffset, scanAllFiles); + // this is under assumption we will be able to add at-least 1 file in the new offset + for (int idx = 0; idx < indexedManifests.size() && shouldContinueReading; idx++) { + // be rest assured curPos >= startFileIndex + curPos = indexedManifests.get(idx).second(); + try (CloseableIterable taskIterable = + MicroBatches.openManifestFile( + table.io(), + table.specs(), + caseSensitive, + curSnapshot, + indexedManifests.get(idx).first(), + scanAllFiles); + CloseableIterator taskIter = taskIterable.iterator()) { + while (taskIter.hasNext()) { + FileScanTask task = taskIter.next(); + if (curPos >= startPosOfSnapOffset) { + // TODO : use readLimit provided in function param, the readLimits are derived from + // these 2 properties. + if ((curFilesAdded + 1) > maxFilesPerMicroBatch + || (curRecordCount + task.file().recordCount()) > maxRecordsPerMicroBatch) { + shouldContinueReading = false; + break; + } + + curFilesAdded += 1; + curRecordCount += task.file().recordCount(); + } + ++curPos; + } + } catch (IOException ioe) { + LOG.warn("Failed to close task iterable", ioe); + } + } + // if the currentSnapShot was also the mostRecentSnapshot then break + if (curSnapshot.snapshotId() == table.currentSnapshot().snapshotId()) { + break; + } + + // if everything was OK and we consumed complete snapshot then move to next snapshot + if (shouldContinueReading) { + Snapshot nextValid = nextValidSnapshot(curSnapshot); + if (nextValid == null) { + // nextValide implies all the remaining snapshots should be skipped. 
+ break; + } + // we found the next available snapshot, continue from there. + curSnapshot = nextValid; + startPosOfSnapOffset = -1; + // if anyhow we are moving to next snapshot we should only scan addedFiles + scanAllFiles = false; + } + } + + StreamingOffset latestStreamingOffset = + new StreamingOffset(curSnapshot.snapshotId(), curPos, scanAllFiles); + + // if no new data arrived, then return null. + return latestStreamingOffset.equals(startingOffset) ? null : latestStreamingOffset; + } + + /** + * Get the next snapshot skiping over rewrite and delete snapshots. + * + * @param curSnapshot the current snapshot + * @return the next valid snapshot (not a rewrite or delete snapshot), returns null if all + * remaining snapshots should be skipped. + */ + private Snapshot nextValidSnapshot(Snapshot curSnapshot) { + Snapshot nextSnapshot = SnapshotUtil.snapshotAfter(table, curSnapshot.snapshotId()); + // skip over rewrite and delete snapshots + while (!shouldProcess(nextSnapshot)) { + LOG.debug("Skipping snapshot: {} of table {}", nextSnapshot.snapshotId(), table.name()); + // if the currentSnapShot was also the mostRecentSnapshot then break + if (nextSnapshot.snapshotId() == table.currentSnapshot().snapshotId()) { + return null; + } + nextSnapshot = SnapshotUtil.snapshotAfter(table, nextSnapshot.snapshotId()); + } + return nextSnapshot; + } + + private long addedFilesCount(Snapshot snapshot) { + long addedFilesCount = + PropertyUtil.propertyAsLong(snapshot.summary(), SnapshotSummary.ADDED_FILES_PROP, -1); + // If snapshotSummary doesn't have SnapshotSummary.ADDED_FILES_PROP, + // iterate through addedFiles iterator to find addedFilesCount. + return addedFilesCount == -1 + ? Iterables.size(snapshot.addedDataFiles(table.io())) + : addedFilesCount; + } + + private void validateCurrentSnapshotExists(Snapshot snapshot, StreamingOffset currentOffset) { + if (snapshot == null) { + throw new IllegalStateException( + String.format( + "Cannot load current offset at snapshot %d, the snapshot was expired or removed", + currentOffset.snapshotId())); + } + } + + @Override + public ReadLimit getDefaultReadLimit() { + if (maxFilesPerMicroBatch != Integer.MAX_VALUE + && maxRecordsPerMicroBatch != Integer.MAX_VALUE) { + ReadLimit[] readLimits = new ReadLimit[2]; + readLimits[0] = ReadLimit.maxFiles(maxFilesPerMicroBatch); + readLimits[1] = ReadLimit.maxRows(maxFilesPerMicroBatch); + return ReadLimit.compositeLimit(readLimits); + } else if (maxFilesPerMicroBatch != Integer.MAX_VALUE) { + return ReadLimit.maxFiles(maxFilesPerMicroBatch); + } else if (maxRecordsPerMicroBatch != Integer.MAX_VALUE) { + return ReadLimit.maxRows(maxRecordsPerMicroBatch); + } else { + return ReadLimit.allAvailable(); + } + } + + private static class InitialOffsetStore { + private final Table table; + private final FileIO io; + private final String initialOffsetLocation; + private final Long fromTimestamp; + + InitialOffsetStore(Table table, String checkpointLocation, Long fromTimestamp) { + this.table = table; + this.io = table.io(); + this.initialOffsetLocation = SLASH.join(checkpointLocation, "offsets/0"); + this.fromTimestamp = fromTimestamp; + } + + public StreamingOffset initialOffset() { + InputFile inputFile = io.newInputFile(initialOffsetLocation); + if (inputFile.exists()) { + return readOffset(inputFile); + } + + table.refresh(); + StreamingOffset offset = determineStartingOffset(table, fromTimestamp); + + OutputFile outputFile = io.newOutputFile(initialOffsetLocation); + writeOffset(offset, outputFile); + + return offset; + 
} + + private void writeOffset(StreamingOffset offset, OutputFile file) { + try (OutputStream outputStream = file.create()) { + BufferedWriter writer = + new BufferedWriter(new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)); + writer.write(offset.json()); + writer.flush(); + } catch (IOException ioException) { + throw new UncheckedIOException( + String.format("Failed writing offset to: %s", initialOffsetLocation), ioException); + } + } + + private StreamingOffset readOffset(InputFile file) { + try (InputStream in = file.newStream()) { + return StreamingOffset.fromJson(in); + } catch (IOException ioException) { + throw new UncheckedIOException( + String.format("Failed reading offset from: %s", initialOffsetLocation), ioException); + } + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedFanoutWriter.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedFanoutWriter.java new file mode 100644 index 000000000000..c34ad2f3ad4a --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedFanoutWriter.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
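The micro-batch stream above honors per-batch admission limits and snapshot-skipping flags taken from the read options; a hedged example of a streaming read that sets them (the option keys are assumptions based on SparkReadOptions naming and should be checked against that class, the values are illustrative, and the table identifier form depends on the catalog setup):

    Dataset<Row> changes =
        spark.readStream()
            .format("iceberg")
            .option("stream-from-timestamp", "1714000000000")        // start after this epoch millis
            .option("streaming-skip-delete-snapshots", "true")
            .option("streaming-skip-overwrite-snapshots", "true")
            .option("streaming-max-files-per-micro-batch", "1000")   // drives ReadLimit.maxFiles
            .option("streaming-max-rows-per-micro-batch", "100000")  // drives ReadLimit.maxRows
            .load("db.table");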
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.PartitionedFanoutWriter; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; + +public class SparkPartitionedFanoutWriter extends PartitionedFanoutWriter { + private final PartitionKey partitionKey; + private final InternalRowWrapper internalRowWrapper; + + public SparkPartitionedFanoutWriter( + PartitionSpec spec, + FileFormat format, + FileAppenderFactory appenderFactory, + OutputFileFactory fileFactory, + FileIO io, + long targetFileSize, + Schema schema, + StructType sparkSchema) { + super(spec, format, appenderFactory, fileFactory, io, targetFileSize); + this.partitionKey = new PartitionKey(spec, schema); + this.internalRowWrapper = new InternalRowWrapper(sparkSchema, schema.asStruct()); + } + + @Override + protected PartitionKey partition(InternalRow row) { + partitionKey.partition(internalRowWrapper.wrap(row)); + return partitionKey; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedWriter.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedWriter.java new file mode 100644 index 000000000000..6904446829e4 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedWriter.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.PartitionedWriter; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; + +public class SparkPartitionedWriter extends PartitionedWriter { + private final PartitionKey partitionKey; + private final InternalRowWrapper internalRowWrapper; + + public SparkPartitionedWriter( + PartitionSpec spec, + FileFormat format, + FileAppenderFactory appenderFactory, + OutputFileFactory fileFactory, + FileIO io, + long targetFileSize, + Schema schema, + StructType sparkSchema) { + super(spec, format, appenderFactory, fileFactory, io, targetFileSize); + this.partitionKey = new PartitionKey(spec, schema); + this.internalRowWrapper = new InternalRowWrapper(sparkSchema, schema.asStruct()); + } + + @Override + protected PartitionKey partition(InternalRow row) { + partitionKey.partition(internalRowWrapper.wrap(row)); + return partitionKey; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitioningAwareScan.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitioningAwareScan.java new file mode 100644 index 000000000000..141dd4dcba0e --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitioningAwareScan.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
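The two writers above differ only in how they dispatch rows: the fanout variant keeps a file open per partition and tolerates unsorted input, while SparkPartitionedWriter expects rows to arrive clustered by partition and closes files as soon as the partition changes. A sketch of choosing between them (all constructor arguments and the toggle are assumed inputs; the property name in the comment is only a likely candidate):

    TaskWriter<InternalRow> writer =
        useFanout  // e.g. driven by a table property such as write.spark.fanout.enabled
            ? new SparkPartitionedFanoutWriter(
                spec, format, appenderFactory, fileFactory, io, targetFileSize, schema, sparkSchema)
            : new SparkPartitionedWriter(
                spec, format, appenderFactory, fileFactory, io, targetFileSize, schema, sparkSchema);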
+ */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Scan; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.connector.read.SupportsReportPartitioning; +import org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning; +import org.apache.spark.sql.connector.read.partitioning.Partitioning; +import org.apache.spark.sql.connector.read.partitioning.UnknownPartitioning; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +abstract class SparkPartitioningAwareScan extends SparkScan + implements SupportsReportPartitioning { + + private static final Logger LOG = LoggerFactory.getLogger(SparkPartitioningAwareScan.class); + + private final Scan> scan; + private final boolean preserveDataGrouping; + + private Set specs = null; // lazy cache of scanned specs + private List tasks = null; // lazy cache of uncombined tasks + private List> taskGroups = null; // lazy cache of task groups + private StructType groupingKeyType = null; // lazy cache of the grouping key type + private Transform[] groupingKeyTransforms = null; // lazy cache of grouping key transforms + + SparkPartitioningAwareScan( + SparkSession spark, + Table table, + Scan> scan, + SparkReadConf readConf, + Schema expectedSchema, + List filters, + Supplier scanReportSupplier) { + super(spark, table, readConf, expectedSchema, filters, scanReportSupplier); + + this.scan = scan; + this.preserveDataGrouping = readConf.preserveDataGrouping(); + + if (scan == null) { + this.specs = Collections.emptySet(); + this.tasks = Collections.emptyList(); + this.taskGroups = Collections.emptyList(); + } + } + + protected abstract Class taskJavaClass(); + + protected Scan> scan() { + return scan; + } + + @Override + public Partitioning outputPartitioning() { + if (groupingKeyType().fields().isEmpty()) { + LOG.info( + "Reporting UnknownPartitioning with {} partition(s) for table {}", + taskGroups().size(), + table().name()); + return new UnknownPartitioning(taskGroups().size()); + } else { + LOG.info( + "Reporting KeyGroupedPartitioning by {} with {} partition(s) for table {}", + groupingKeyTransforms(), + taskGroups().size(), + table().name()); + return new KeyGroupedPartitioning(groupingKeyTransforms(), taskGroups().size()); + } + } + + @Override + protected StructType 
groupingKeyType() { + if (groupingKeyType == null) { + if (preserveDataGrouping) { + this.groupingKeyType = computeGroupingKeyType(); + } else { + this.groupingKeyType = StructType.of(); + } + } + + return groupingKeyType; + } + + private StructType computeGroupingKeyType() { + return org.apache.iceberg.Partitioning.groupingKeyType(expectedSchema(), specs()); + } + + private Transform[] groupingKeyTransforms() { + if (groupingKeyTransforms == null) { + Map fieldsById = indexFieldsById(specs()); + + List groupingKeyFields = + groupingKeyType().fields().stream() + .map(field -> fieldsById.get(field.fieldId())) + .collect(Collectors.toList()); + + Schema schema = SnapshotUtil.schemaFor(table(), branch()); + this.groupingKeyTransforms = Spark3Util.toTransforms(schema, groupingKeyFields); + } + + return groupingKeyTransforms; + } + + private Map indexFieldsById(Iterable specIterable) { + Map fieldsById = Maps.newHashMap(); + + for (PartitionSpec spec : specIterable) { + for (PartitionField field : spec.fields()) { + fieldsById.putIfAbsent(field.fieldId(), field); + } + } + + return fieldsById; + } + + protected Set specs() { + if (specs == null) { + // avoid calling equals/hashCode on specs as those methods are relatively expensive + IntStream specIds = tasks().stream().mapToInt(task -> task.spec().specId()).distinct(); + this.specs = specIds.mapToObj(id -> table().specs().get(id)).collect(Collectors.toSet()); + } + + return specs; + } + + protected synchronized List tasks() { + if (tasks == null) { + try (CloseableIterable taskIterable = scan.planFiles()) { + List plannedTasks = Lists.newArrayList(); + + for (ScanTask task : taskIterable) { + ValidationException.check( + taskJavaClass().isInstance(task), + "Unsupported task type, expected a subtype of %s: %", + taskJavaClass().getName(), + task.getClass().getName()); + + plannedTasks.add(taskJavaClass().cast(task)); + } + + this.tasks = plannedTasks; + } catch (IOException e) { + throw new UncheckedIOException("Failed to close scan: " + scan, e); + } + } + + return tasks; + } + + @Override + protected synchronized List> taskGroups() { + if (taskGroups == null) { + if (groupingKeyType().fields().isEmpty()) { + CloseableIterable> plannedTaskGroups = + TableScanUtil.planTaskGroups( + CloseableIterable.withNoopClose(tasks()), + adjustSplitSize(tasks(), scan.targetSplitSize()), + scan.splitLookback(), + scan.splitOpenFileCost()); + this.taskGroups = Lists.newArrayList(plannedTaskGroups); + + LOG.debug( + "Planned {} task group(s) without data grouping for table {}", + taskGroups.size(), + table().name()); + + } else { + List> plannedTaskGroups = + TableScanUtil.planTaskGroups( + tasks(), + adjustSplitSize(tasks(), scan.targetSplitSize()), + scan.splitLookback(), + scan.splitOpenFileCost(), + groupingKeyType()); + StructLikeSet plannedGroupingKeys = collectGroupingKeys(plannedTaskGroups); + + LOG.debug( + "Planned {} task group(s) with {} grouping key type and {} unique grouping key(s) for table {}", + plannedTaskGroups.size(), + groupingKeyType(), + plannedGroupingKeys.size(), + table().name()); + + this.taskGroups = plannedTaskGroups; + } + } + + return taskGroups; + } + + // only task groups can be reset while resetting tasks + // the set of scanned specs and grouping key type must never change + protected void resetTasks(List filteredTasks) { + this.taskGroups = null; + this.tasks = filteredTasks; + } + + private StructLikeSet collectGroupingKeys(Iterable> taskGroupIterable) { + StructLikeSet keys = StructLikeSet.create(groupingKeyType()); + + 
for (ScanTaskGroup taskGroup : taskGroupIterable) { + keys.add(taskGroup.groupingKey()); + } + + return keys; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPlanningUtil.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPlanningUtil.java new file mode 100644 index 000000000000..9cdec2c8f463 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPlanningUtil.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.hadoop.Util; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.JavaHash; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; + +class SparkPlanningUtil { + + public static final String[] NO_LOCATION_PREFERENCE = new String[0]; + + private SparkPlanningUtil() {} + + public static String[][] fetchBlockLocations( + FileIO io, List> taskGroups) { + String[][] locations = new String[taskGroups.size()][]; + + Tasks.range(taskGroups.size()) + .stopOnFailure() + .executeWith(ThreadPools.getWorkerPool()) + .run(index -> locations[index] = Util.blockLocations(io, taskGroups.get(index))); + + return locations; + } + + public static String[][] assignExecutors( + List> taskGroups, List executorLocations) { + Map> partitionHashes = Maps.newHashMap(); + String[][] locations = new String[taskGroups.size()][]; + + for (int index = 0; index < taskGroups.size(); index++) { + locations[index] = assign(taskGroups.get(index), executorLocations, partitionHashes); + } + + return locations; + } + + private static String[] assign( + ScanTaskGroup taskGroup, + List executorLocations, + Map> partitionHashes) { + List locations = Lists.newArrayList(); + + for (ScanTask task : taskGroup.tasks()) { + if (task.isFileScanTask()) { + FileScanTask fileTask = task.asFileScanTask(); + PartitionSpec spec = fileTask.spec(); + if (spec.isPartitioned() && !fileTask.deletes().isEmpty()) { + JavaHash partitionHash = + partitionHashes.computeIfAbsent(spec.specId(), key -> partitionHash(spec)); + int partitionHashCode = partitionHash.hash(fileTask.partition()); + int index = Math.floorMod(partitionHashCode, executorLocations.size()); + String executorLocation = executorLocations.get(index); + locations.add(executorLocation); + } + } + } + + return 
locations.toArray(NO_LOCATION_PREFERENCE); + } + + private static JavaHash partitionHash(PartitionSpec spec) { + return JavaHash.forType(spec.partitionType()); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeletesRewrite.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeletesRewrite.java new file mode 100644 index 000000000000..73e6ab01563c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeletesRewrite.java @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PositionDeletesTable; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.deletes.DeleteGranularity; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.io.ClusteredPositionDeleteWriter; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.PositionDeletesRewriteCoordinator; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.util.DeleteFileSet; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.DataWriter; +import org.apache.spark.sql.connector.write.DataWriterFactory; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.PhysicalWriteInfo; +import org.apache.spark.sql.connector.write.Write; +import org.apache.spark.sql.connector.write.WriterCommitMessage; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * {@link Write} class for rewriting position delete files from Spark. Responsible for creating + * {@link SparkPositionDeletesRewrite.PositionDeleteBatchWrite} + * + *
<p>
This class is meant to be used for an action to rewrite position delete files. Hence, it + * assumes all position deletes to rewrite have come from {@link ScanTaskSetManager} and that all + * have the same partition spec id and partition values. + */ +public class SparkPositionDeletesRewrite implements Write { + + private final JavaSparkContext sparkContext; + private final Table table; + private final String queryId; + private final FileFormat format; + private final long targetFileSize; + private final DeleteGranularity deleteGranularity; + private final Schema writeSchema; + private final StructType dsSchema; + private final String fileSetId; + private final int specId; + private final StructLike partition; + private final Map writeProperties; + + /** + * Constructs a {@link SparkPositionDeletesRewrite}. + * + * @param spark Spark session + * @param table instance of {@link PositionDeletesTable} + * @param writeConf Spark write config + * @param writeInfo Spark write info + * @param writeSchema Iceberg output schema + * @param dsSchema schema of original incoming position deletes dataset + * @param specId spec id of position deletes + * @param partition partition value of position deletes + */ + SparkPositionDeletesRewrite( + SparkSession spark, + Table table, + SparkWriteConf writeConf, + LogicalWriteInfo writeInfo, + Schema writeSchema, + StructType dsSchema, + int specId, + StructLike partition) { + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.queryId = writeInfo.queryId(); + this.format = writeConf.deleteFileFormat(); + this.targetFileSize = writeConf.targetDeleteFileSize(); + this.deleteGranularity = writeConf.deleteGranularity(); + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + this.fileSetId = writeConf.rewrittenFileSetId(); + this.specId = specId; + this.partition = partition; + this.writeProperties = writeConf.writeProperties(); + } + + @Override + public BatchWrite toBatch() { + return new PositionDeleteBatchWrite(); + } + + /** {@link BatchWrite} class for rewriting position deletes files from Spark */ + class PositionDeleteBatchWrite implements BatchWrite { + + @Override + public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { + // broadcast the table metadata as the writer factory will be sent to executors + Broadcast
<Table>
tableBroadcast = + sparkContext.broadcast(SerializableTableWithSize.copyOf(table)); + return new PositionDeletesWriterFactory( + tableBroadcast, + queryId, + format, + targetFileSize, + deleteGranularity, + writeSchema, + dsSchema, + specId, + partition, + writeProperties); + } + + @Override + public boolean useCommitCoordinator() { + return false; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + PositionDeletesRewriteCoordinator coordinator = PositionDeletesRewriteCoordinator.get(); + coordinator.stageRewrite(table, fileSetId, DeleteFileSet.of(files(messages))); + } + + @Override + public void abort(WriterCommitMessage[] messages) { + SparkCleanupUtil.deleteFiles("job abort", table.io(), files(messages)); + } + + private List files(WriterCommitMessage[] messages) { + List files = Lists.newArrayList(); + + for (WriterCommitMessage message : messages) { + if (message != null) { + DeleteTaskCommit taskCommit = (DeleteTaskCommit) message; + files.addAll(Arrays.asList(taskCommit.files())); + } + } + + return files; + } + } + + /** + * Writer factory for position deletes metadata table. Responsible for creating {@link + * DeleteWriter}. + * + *
<p>
This writer is meant to be used for an action to rewrite delete files. Hence, it makes an + * assumption that all incoming deletes belong to the same partition, and that incoming dataset is + * from {@link ScanTaskSetManager}. + */ + static class PositionDeletesWriterFactory implements DataWriterFactory { + private final Broadcast
<Table>
tableBroadcast; + private final String queryId; + private final FileFormat format; + private final Long targetFileSize; + private final DeleteGranularity deleteGranularity; + private final Schema writeSchema; + private final StructType dsSchema; + private final int specId; + private final StructLike partition; + private final Map writeProperties; + + PositionDeletesWriterFactory( + Broadcast
tableBroadcast, + String queryId, + FileFormat format, + long targetFileSize, + DeleteGranularity deleteGranularity, + Schema writeSchema, + StructType dsSchema, + int specId, + StructLike partition, + Map writeProperties) { + this.tableBroadcast = tableBroadcast; + this.queryId = queryId; + this.format = format; + this.targetFileSize = targetFileSize; + this.deleteGranularity = deleteGranularity; + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + this.specId = specId; + this.partition = partition; + this.writeProperties = writeProperties; + } + + @Override + public DataWriter createWriter(int partitionId, long taskId) { + Table table = tableBroadcast.value(); + + OutputFileFactory deleteFileFactory = + OutputFileFactory.builderFor(table, partitionId, taskId) + .format(format) + .operationId(queryId) + .suffix("deletes") + .build(); + + Schema positionDeleteRowSchema = positionDeleteRowSchema(); + StructType deleteSparkType = deleteSparkType(); + StructType deleteSparkTypeWithoutRow = deleteSparkTypeWithoutRow(); + + SparkFileWriterFactory writerFactoryWithRow = + SparkFileWriterFactory.builderFor(table) + .deleteFileFormat(format) + .positionDeleteRowSchema(positionDeleteRowSchema) + .positionDeleteSparkType(deleteSparkType) + .writeProperties(writeProperties) + .build(); + SparkFileWriterFactory writerFactoryWithoutRow = + SparkFileWriterFactory.builderFor(table) + .deleteFileFormat(format) + .positionDeleteSparkType(deleteSparkTypeWithoutRow) + .writeProperties(writeProperties) + .build(); + + return new DeleteWriter( + table, + writerFactoryWithRow, + writerFactoryWithoutRow, + deleteFileFactory, + targetFileSize, + deleteGranularity, + dsSchema, + specId, + partition); + } + + private Schema positionDeleteRowSchema() { + return new Schema( + writeSchema + .findField(MetadataColumns.DELETE_FILE_ROW_FIELD_NAME) + .type() + .asStructType() + .fields()); + } + + private StructType deleteSparkType() { + return new StructType( + new StructField[] { + dsSchema.apply(MetadataColumns.DELETE_FILE_PATH.name()), + dsSchema.apply(MetadataColumns.DELETE_FILE_POS.name()), + dsSchema.apply(MetadataColumns.DELETE_FILE_ROW_FIELD_NAME) + }); + } + + private StructType deleteSparkTypeWithoutRow() { + return new StructType( + new StructField[] { + dsSchema.apply(MetadataColumns.DELETE_FILE_PATH.name()), + dsSchema.apply(MetadataColumns.DELETE_FILE_POS.name()), + }); + } + } + + /** + * Writer for position deletes metadata table. + * + *
<p>
Iceberg specifies delete files schema as having either 'row' as a required field, or omits + * 'row' altogether. This is to ensure accuracy of delete file statistics on 'row' column. Hence, + * this writer, if receiving source position deletes with null and non-null rows, redirects rows + * with null 'row' to one file writer, and non-null 'row' to another file writer. + * + *
<p>
This writer is meant to be used for an action to rewrite delete files. Hence, it makes an + * assumption that all incoming deletes belong to the same partition. + */ + private static class DeleteWriter implements DataWriter { + private final SparkFileWriterFactory writerFactoryWithRow; + private final SparkFileWriterFactory writerFactoryWithoutRow; + private final OutputFileFactory deleteFileFactory; + private final long targetFileSize; + private final DeleteGranularity deleteGranularity; + private final PositionDelete positionDelete; + private final FileIO io; + private final PartitionSpec spec; + private final int fileOrdinal; + private final int positionOrdinal; + private final int rowOrdinal; + private final int rowSize; + private final StructLike partition; + + private ClusteredPositionDeleteWriter writerWithRow; + private ClusteredPositionDeleteWriter writerWithoutRow; + private boolean closed = false; + + /** + * Constructs a {@link DeleteWriter}. + * + * @param table position deletes metadata table + * @param writerFactoryWithRow writer factory for deletes with non-null 'row' + * @param writerFactoryWithoutRow writer factory for deletes with null 'row' + * @param deleteFileFactory delete file factory + * @param targetFileSize target file size + * @param dsSchema schema of incoming dataset of position deletes + * @param specId partition spec id of incoming position deletes. All incoming partition deletes + * are required to have the same spec id. + * @param partition partition value of incoming position delete. All incoming partition deletes + * are required to have the same partition. + */ + DeleteWriter( + Table table, + SparkFileWriterFactory writerFactoryWithRow, + SparkFileWriterFactory writerFactoryWithoutRow, + OutputFileFactory deleteFileFactory, + long targetFileSize, + DeleteGranularity deleteGranularity, + StructType dsSchema, + int specId, + StructLike partition) { + this.deleteFileFactory = deleteFileFactory; + this.targetFileSize = targetFileSize; + this.deleteGranularity = deleteGranularity; + this.writerFactoryWithRow = writerFactoryWithRow; + this.writerFactoryWithoutRow = writerFactoryWithoutRow; + this.positionDelete = PositionDelete.create(); + this.io = table.io(); + this.spec = table.specs().get(specId); + this.partition = partition; + + this.fileOrdinal = dsSchema.fieldIndex(MetadataColumns.DELETE_FILE_PATH.name()); + this.positionOrdinal = dsSchema.fieldIndex(MetadataColumns.DELETE_FILE_POS.name()); + + this.rowOrdinal = dsSchema.fieldIndex(MetadataColumns.DELETE_FILE_ROW_FIELD_NAME); + DataType type = dsSchema.apply(MetadataColumns.DELETE_FILE_ROW_FIELD_NAME).dataType(); + Preconditions.checkArgument( + type instanceof StructType, "Expected row as struct type but was %s", type); + this.rowSize = ((StructType) type).size(); + } + + @Override + public void write(InternalRow record) throws IOException { + String file = record.getString(fileOrdinal); + long position = record.getLong(positionOrdinal); + InternalRow row = record.getStruct(rowOrdinal, rowSize); + if (row != null) { + positionDelete.set(file, position, row); + lazyWriterWithRow().write(positionDelete, spec, partition); + } else { + positionDelete.set(file, position, null); + lazyWriterWithoutRow().write(positionDelete, spec, partition); + } + } + + @Override + public WriterCommitMessage commit() throws IOException { + close(); + return new DeleteTaskCommit(allDeleteFiles()); + } + + @Override + public void abort() throws IOException { + close(); + SparkCleanupUtil.deleteTaskFiles(io, 
allDeleteFiles()); + } + + @Override + public void close() throws IOException { + if (!closed) { + if (writerWithRow != null) { + writerWithRow.close(); + } + if (writerWithoutRow != null) { + writerWithoutRow.close(); + } + this.closed = true; + } + } + + private ClusteredPositionDeleteWriter lazyWriterWithRow() { + if (writerWithRow == null) { + this.writerWithRow = + new ClusteredPositionDeleteWriter<>( + writerFactoryWithRow, deleteFileFactory, io, targetFileSize, deleteGranularity); + } + return writerWithRow; + } + + private ClusteredPositionDeleteWriter lazyWriterWithoutRow() { + if (writerWithoutRow == null) { + this.writerWithoutRow = + new ClusteredPositionDeleteWriter<>( + writerFactoryWithoutRow, deleteFileFactory, io, targetFileSize, deleteGranularity); + } + return writerWithoutRow; + } + + private List allDeleteFiles() { + List allDeleteFiles = Lists.newArrayList(); + if (writerWithRow != null) { + allDeleteFiles.addAll(writerWithRow.result().deleteFiles()); + } + if (writerWithoutRow != null) { + allDeleteFiles.addAll(writerWithoutRow.result().deleteFiles()); + } + return allDeleteFiles; + } + } + + public static class DeleteTaskCommit implements WriterCommitMessage { + private final DeleteFile[] taskFiles; + + DeleteTaskCommit(List deleteFiles) { + this.taskFiles = deleteFiles.toArray(new DeleteFile[0]); + } + + DeleteFile[] files() { + return taskFiles; + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeletesRewriteBuilder.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeletesRewriteBuilder.java new file mode 100644 index 000000000000..9fccc05ea25c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeletesRewriteBuilder.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
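A minimal sketch of the routing rule described in the DeleteWriter Javadoc above: position deletes that still carry the deleted row image go to one lazily created writer, and deletes without it go to another, so no single output file mixes null and non-null 'row' values. Sink is a hypothetical stand-in for Iceberg's ClusteredPositionDeleteWriter.

// Editor's illustrative sketch, not part of the patch.
import java.util.ArrayList;
import java.util.List;

public class DeleteRoutingSketch {
  static class Sink {
    final List<String> entries = new ArrayList<>();
    void write(String path, long pos, Object row) { entries.add(path + ":" + pos); }
  }

  private Sink withRow;
  private Sink withoutRow;

  void write(String path, long pos, Object row) {
    if (row != null) {
      if (withRow == null) withRow = new Sink();        // mirrors lazyWriterWithRow()
      withRow.write(path, pos, row);
    } else {
      if (withoutRow == null) withoutRow = new Sink();  // mirrors lazyWriterWithoutRow()
      withoutRow.write(path, pos, null);
    }
  }

  public static void main(String[] args) {
    DeleteRoutingSketch router = new DeleteRoutingSketch();
    router.write("data/f1.parquet", 10L, "deleted row image"); // routed to withRow
    router.write("data/f1.parquet", 11L, null);                // routed to withoutRow
    System.out.println(router.withRow.entries + " / " + router.withoutRow.entries);
  }
}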
+ */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.ContentScanTask; +import org.apache.iceberg.PositionDeletesScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.Write; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.types.StructType; + +/** + * Builder class for rewrites of position delete files from Spark. Responsible for creating {@link + * SparkPositionDeletesRewrite}. + * + *
<p>
This class is meant to be used for an action to rewrite delete files. Hence, it makes an + * assumption that all incoming deletes belong to the same partition, and that incoming dataset is + * from {@link ScanTaskSetManager}. + */ +public class SparkPositionDeletesRewriteBuilder implements WriteBuilder { + + private final SparkSession spark; + private final Table table; + private final SparkWriteConf writeConf; + private final LogicalWriteInfo writeInfo; + private final StructType dsSchema; + private final Schema writeSchema; + + SparkPositionDeletesRewriteBuilder( + SparkSession spark, Table table, String branch, LogicalWriteInfo info) { + this.spark = spark; + this.table = table; + this.writeConf = new SparkWriteConf(spark, table, branch, info.options()); + this.writeInfo = info; + this.dsSchema = info.schema(); + this.writeSchema = SparkSchemaUtil.convert(table.schema(), dsSchema, writeConf.caseSensitive()); + } + + @Override + public Write build() { + String fileSetId = writeConf.rewrittenFileSetId(); + + Preconditions.checkArgument( + fileSetId != null, "Can only write to %s via actions", table.name()); + + // all files of rewrite group have same partition and spec id + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + List tasks = taskSetManager.fetchTasks(table, fileSetId); + Preconditions.checkArgument( + tasks != null && !tasks.isEmpty(), "No scan tasks found for %s", fileSetId); + + int specId = specId(fileSetId, tasks); + StructLike partition = partition(fileSetId, tasks); + + return new SparkPositionDeletesRewrite( + spark, table, writeConf, writeInfo, writeSchema, dsSchema, specId, partition); + } + + private int specId(String fileSetId, List tasks) { + Set specIds = tasks.stream().map(t -> t.spec().specId()).collect(Collectors.toSet()); + Preconditions.checkArgument( + specIds.size() == 1, + "All scan tasks of %s are expected to have same spec id, but got %s", + fileSetId, + Joiner.on(",").join(specIds)); + return tasks.get(0).spec().specId(); + } + + private StructLike partition(String fileSetId, List tasks) { + StructLikeSet partitions = StructLikeSet.create(tasks.get(0).spec().partitionType()); + tasks.stream().map(ContentScanTask::partition).forEach(partitions::add); + Preconditions.checkArgument( + partitions.size() == 1, + "All scan tasks of %s are expected to have the same partition, but got %s", + fileSetId, + Joiner.on(",").join(partitions)); + return tasks.get(0).partition(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaOperation.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaOperation.java new file mode 100644 index 000000000000..b970b4ab6c78 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaOperation.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
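The builder above rejects rewrite groups whose tasks span more than one spec id or partition value. A reduced sketch of those preconditions in plain Java, with TaskInfo standing in for PositionDeletesScanTask:

// Editor's illustrative sketch, not part of the patch.
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class RewriteGroupChecks {
  record TaskInfo(int specId, String partition) {}

  static int requireSingleSpecId(List<TaskInfo> tasks) {
    Set<Integer> specIds = tasks.stream().map(TaskInfo::specId).collect(Collectors.toSet());
    if (specIds.size() != 1) {
      throw new IllegalArgumentException("Expected one spec id, got " + specIds);
    }
    return tasks.get(0).specId();
  }

  static String requireSinglePartition(List<TaskInfo> tasks) {
    Set<String> partitions = tasks.stream().map(TaskInfo::partition).collect(Collectors.toSet());
    if (partitions.size() != 1) {
      throw new IllegalArgumentException("Expected one partition, got " + partitions);
    }
    return tasks.get(0).partition();
  }

  public static void main(String[] args) {
    List<TaskInfo> group =
        List.of(new TaskInfo(1, "dt=2024-01-01"), new TaskInfo(1, "dt=2024-01-01"));
    System.out.println(requireSingleSpecId(group) + " / " + requireSinglePartition(group));
  }
}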
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.write.DeltaWriteBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperation; +import org.apache.spark.sql.connector.write.RowLevelOperationInfo; +import org.apache.spark.sql.connector.write.SupportsDelta; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +class SparkPositionDeltaOperation implements RowLevelOperation, SupportsDelta { + + private final SparkSession spark; + private final Table table; + private final String branch; + private final Command command; + private final IsolationLevel isolationLevel; + + // lazy vars + private ScanBuilder lazyScanBuilder; + private Scan configuredScan; + private DeltaWriteBuilder lazyWriteBuilder; + + SparkPositionDeltaOperation( + SparkSession spark, + Table table, + String branch, + RowLevelOperationInfo info, + IsolationLevel isolationLevel) { + this.spark = spark; + this.table = table; + this.branch = branch; + this.command = info.command(); + this.isolationLevel = isolationLevel; + } + + @Override + public Command command() { + return command; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + if (lazyScanBuilder == null) { + this.lazyScanBuilder = + new SparkScanBuilder(spark, table, branch, options) { + @Override + public Scan build() { + Scan scan = super.buildMergeOnReadScan(); + SparkPositionDeltaOperation.this.configuredScan = scan; + return scan; + } + }; + } + + return lazyScanBuilder; + } + + @Override + public DeltaWriteBuilder newWriteBuilder(LogicalWriteInfo info) { + if (lazyWriteBuilder == null) { + // don't validate the scan is not null as if the condition evaluates to false, + // the optimizer replaces the original scan relation with a local relation + lazyWriteBuilder = + new SparkPositionDeltaWriteBuilder( + spark, table, branch, command, configuredScan, isolationLevel, info); + } + + return lazyWriteBuilder; + } + + @Override + public NamedReference[] requiredMetadataAttributes() { + NamedReference specId = Expressions.column(MetadataColumns.SPEC_ID.name()); + NamedReference partition = Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME); + return new NamedReference[] {specId, partition}; + } + + @Override + public NamedReference[] rowId() { + NamedReference file = Expressions.column(MetadataColumns.FILE_PATH.name()); + NamedReference pos = Expressions.column(MetadataColumns.ROW_POSITION.name()); + return new NamedReference[] {file, pos}; + } + + @Override + public boolean representUpdateAsDeleteAndInsert() { + return true; + } +} diff --git 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java new file mode 100644 index 000000000000..18020ee935b6 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java @@ -0,0 +1,857 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.IsolationLevel.SERIALIZABLE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; + +import java.io.IOException; +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.SnapshotUpdate; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.BaseDeleteLoader; +import org.apache.iceberg.data.DeleteLoader; +import org.apache.iceberg.deletes.DeleteGranularity; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.deletes.PositionDeleteIndex; +import org.apache.iceberg.encryption.EncryptingFileIO; +import org.apache.iceberg.exceptions.CleanableFailure; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.BasePositionDeltaWriter; +import org.apache.iceberg.io.ClusteredDataWriter; +import org.apache.iceberg.io.ClusteredPositionDeleteWriter; +import org.apache.iceberg.io.DataWriteResult; +import org.apache.iceberg.io.DeleteWriteResult; +import org.apache.iceberg.io.FanoutDataWriter; +import org.apache.iceberg.io.FanoutPositionOnlyDeleteWriter; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.PartitioningWriter; +import org.apache.iceberg.io.PositionDeltaWriter; +import org.apache.iceberg.io.WriteResult; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.CommitMetadata; 
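SparkPositionDeltaOperation above identifies target rows by (file path, row position) and reports updates as delete-and-insert. A small, non-Iceberg sketch of what that representation looks like; RowId and Delta are illustrative types only, not the connector's API.

// Editor's illustrative sketch, not part of the patch.
import java.util.ArrayList;
import java.util.List;

public class PositionDeltaSketch {
  record RowId(String file, long pos) {}
  record Delta(List<RowId> deletes, List<String> inserts) {}

  // An UPDATE of one row becomes: delete its old position, insert the new row image.
  static Delta updateAsDelta(RowId target, String newRow) {
    List<RowId> deletes = new ArrayList<>(List.of(target));
    List<String> inserts = new ArrayList<>(List.of(newRow));
    return new Delta(deletes, inserts);
  }

  public static void main(String[] args) {
    System.out.println(updateAsDelta(new RowId("data/file-a.parquet", 42L), "updated row"));
  }
}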
+import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.spark.SparkWriteRequirements; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.CharSequenceSet; +import org.apache.iceberg.util.DeleteFileSet; +import org.apache.iceberg.util.StructProjection; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.write.DeltaBatchWrite; +import org.apache.spark.sql.connector.write.DeltaWrite; +import org.apache.spark.sql.connector.write.DeltaWriter; +import org.apache.spark.sql.connector.write.DeltaWriterFactory; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.PhysicalWriteInfo; +import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.apache.spark.sql.connector.write.WriterCommitMessage; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class SparkPositionDeltaWrite implements DeltaWrite, RequiresDistributionAndOrdering { + + private static final Logger LOG = LoggerFactory.getLogger(SparkPositionDeltaWrite.class); + + private final JavaSparkContext sparkContext; + private final Table table; + private final Command command; + private final SparkBatchQueryScan scan; + private final IsolationLevel isolationLevel; + private final String applicationId; + private final boolean wapEnabled; + private final String wapId; + private final String branch; + private final Map extraSnapshotMetadata; + private final SparkWriteRequirements writeRequirements; + private final Context context; + private final Map writeProperties; + + private boolean cleanupOnAbort = false; + + SparkPositionDeltaWrite( + SparkSession spark, + Table table, + Command command, + SparkBatchQueryScan scan, + IsolationLevel isolationLevel, + SparkWriteConf writeConf, + LogicalWriteInfo info, + Schema dataSchema) { + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.command = command; + this.scan = scan; + this.isolationLevel = isolationLevel; + this.applicationId = spark.sparkContext().applicationId(); + this.wapEnabled = writeConf.wapEnabled(); + this.wapId = writeConf.wapId(); + this.branch = writeConf.branch(); + this.extraSnapshotMetadata = writeConf.extraSnapshotMetadata(); + this.writeRequirements = writeConf.positionDeltaRequirements(command); + this.context = new Context(dataSchema, writeConf, info, writeRequirements); + this.writeProperties = writeConf.writeProperties(); + } + + @Override + public Distribution requiredDistribution() { + Distribution distribution = writeRequirements.distribution(); + LOG.info("Requesting {} as write distribution for table {}", distribution, table.name()); + return distribution; + } + + @Override + public boolean distributionStrictlyRequired() { + return false; + } + + @Override + public SortOrder[] requiredOrdering() { + SortOrder[] ordering = writeRequirements.ordering(); + LOG.info("Requesting {} as write ordering for table {}", ordering, table.name()); + return ordering; + } + + @Override + public long advisoryPartitionSizeInBytes() { 
+ long size = writeRequirements.advisoryPartitionSize(); + LOG.info("Requesting {} bytes advisory partition size for table {}", size, table.name()); + return size; + } + + @Override + public DeltaBatchWrite toBatch() { + return new PositionDeltaBatchWrite(); + } + + private class PositionDeltaBatchWrite implements DeltaBatchWrite { + + @Override + public DeltaWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { + // broadcast large objects since the writer factory will be sent to executors + return new PositionDeltaWriteFactory( + sparkContext.broadcast(SerializableTableWithSize.copyOf(table)), + broadcastRewritableDeletes(), + command, + context, + writeProperties); + } + + private Broadcast> broadcastRewritableDeletes() { + if (context.deleteGranularity() == DeleteGranularity.FILE && scan != null) { + Map rewritableDeletes = scan.rewritableDeletes(); + if (rewritableDeletes != null && !rewritableDeletes.isEmpty()) { + return sparkContext.broadcast(rewritableDeletes); + } + } + return null; + } + + @Override + public boolean useCommitCoordinator() { + return false; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + RowDelta rowDelta = table.newRowDelta(); + + CharSequenceSet referencedDataFiles = CharSequenceSet.empty(); + + int addedDataFilesCount = 0; + int addedDeleteFilesCount = 0; + int removedDeleteFilesCount = 0; + + for (WriterCommitMessage message : messages) { + DeltaTaskCommit taskCommit = (DeltaTaskCommit) message; + + for (DataFile dataFile : taskCommit.dataFiles()) { + rowDelta.addRows(dataFile); + addedDataFilesCount += 1; + } + + for (DeleteFile deleteFile : taskCommit.deleteFiles()) { + rowDelta.addDeletes(deleteFile); + addedDeleteFilesCount += 1; + } + + for (DeleteFile deleteFile : taskCommit.rewrittenDeleteFiles()) { + rowDelta.removeDeletes(deleteFile); + removedDeleteFilesCount += 1; + } + + referencedDataFiles.addAll(Arrays.asList(taskCommit.referencedDataFiles())); + } + + // the scan may be null if the optimizer replaces it with an empty relation + // no validation is needed in this case as the command is independent of the table state + if (scan != null) { + Expression conflictDetectionFilter = conflictDetectionFilter(scan); + rowDelta.conflictDetectionFilter(conflictDetectionFilter); + + rowDelta.validateDataFilesExist(referencedDataFiles); + + if (scan.snapshotId() != null) { + // set the read snapshot ID to check only snapshots that happened after the table was read + // otherwise, the validation will go through all snapshots present in the table + rowDelta.validateFromSnapshot(scan.snapshotId()); + } + + if (command == UPDATE || command == MERGE) { + rowDelta.validateDeletedFiles(); + rowDelta.validateNoConflictingDeleteFiles(); + } + + if (isolationLevel == SERIALIZABLE) { + rowDelta.validateNoConflictingDataFiles(); + } + + String commitMsg = + String.format( + "position delta with %d data files, %d delete files and %d rewritten delete files" + + "(scanSnapshotId: %d, conflictDetectionFilter: %s, isolationLevel: %s)", + addedDataFilesCount, + addedDeleteFilesCount, + removedDeleteFilesCount, + scan.snapshotId(), + conflictDetectionFilter, + isolationLevel); + commitOperation(rowDelta, commitMsg); + + } else { + String commitMsg = + String.format( + "position delta with %d data files and %d delete files (no validation required)", + addedDataFilesCount, addedDeleteFilesCount); + commitOperation(rowDelta, commitMsg); + } + } + + private Expression conflictDetectionFilter(SparkBatchQueryScan queryScan) { + Expression filter 
= Expressions.alwaysTrue(); + + for (Expression expr : queryScan.filterExpressions()) { + filter = Expressions.and(filter, expr); + } + + return filter; + } + + @Override + public void abort(WriterCommitMessage[] messages) { + if (cleanupOnAbort) { + SparkCleanupUtil.deleteFiles("job abort", table.io(), files(messages)); + } else { + LOG.warn("Skipping cleanup of written files"); + } + } + + private List> files(WriterCommitMessage[] messages) { + List> files = Lists.newArrayList(); + + for (WriterCommitMessage message : messages) { + if (message != null) { + DeltaTaskCommit taskCommit = (DeltaTaskCommit) message; + files.addAll(Arrays.asList(taskCommit.dataFiles())); + files.addAll(Arrays.asList(taskCommit.deleteFiles())); + } + } + + return files; + } + + private void commitOperation(SnapshotUpdate operation, String description) { + LOG.info("Committing {} to table {}", description, table); + if (applicationId != null) { + operation.set("spark.app.id", applicationId); + } + + extraSnapshotMetadata.forEach(operation::set); + + CommitMetadata.commitProperties().forEach(operation::set); + + if (wapEnabled && wapId != null) { + // write-audit-publish is enabled for this table and job + // stage the changes without changing the current snapshot + operation.set(SnapshotSummary.STAGED_WAP_ID_PROP, wapId); + operation.stageOnly(); + } + + if (branch != null) { + operation.toBranch(branch); + } + + try { + long start = System.currentTimeMillis(); + operation.commit(); // abort is automatically called if this fails + long duration = System.currentTimeMillis() - start; + LOG.info("Committed in {} ms", duration); + } catch (Exception e) { + cleanupOnAbort = e instanceof CleanableFailure; + throw e; + } + } + } + + public static class DeltaTaskCommit implements WriterCommitMessage { + private final DataFile[] dataFiles; + private final DeleteFile[] deleteFiles; + private final DeleteFile[] rewrittenDeleteFiles; + private final CharSequence[] referencedDataFiles; + + DeltaTaskCommit(WriteResult result) { + this.dataFiles = result.dataFiles(); + this.deleteFiles = result.deleteFiles(); + this.referencedDataFiles = result.referencedDataFiles(); + this.rewrittenDeleteFiles = result.rewrittenDeleteFiles(); + } + + DeltaTaskCommit(DeleteWriteResult result) { + this.dataFiles = new DataFile[0]; + this.deleteFiles = result.deleteFiles().toArray(new DeleteFile[0]); + this.referencedDataFiles = result.referencedDataFiles().toArray(new CharSequence[0]); + this.rewrittenDeleteFiles = result.rewrittenDeleteFiles().toArray(new DeleteFile[0]); + } + + DataFile[] dataFiles() { + return dataFiles; + } + + DeleteFile[] deleteFiles() { + return deleteFiles; + } + + DeleteFile[] rewrittenDeleteFiles() { + return rewrittenDeleteFiles; + } + + CharSequence[] referencedDataFiles() { + return referencedDataFiles; + } + } + + private static class PositionDeltaWriteFactory implements DeltaWriterFactory { + private final Broadcast
<Table>
tableBroadcast; + private final Broadcast> rewritableDeletesBroadcast; + private final Command command; + private final Context context; + private final Map writeProperties; + + PositionDeltaWriteFactory( + Broadcast
tableBroadcast, + Broadcast> rewritableDeletesBroadcast, + Command command, + Context context, + Map writeProperties) { + this.tableBroadcast = tableBroadcast; + this.rewritableDeletesBroadcast = rewritableDeletesBroadcast; + this.command = command; + this.context = context; + this.writeProperties = writeProperties; + } + + @Override + public DeltaWriter createWriter(int partitionId, long taskId) { + Table table = tableBroadcast.value(); + + OutputFileFactory dataFileFactory = + OutputFileFactory.builderFor(table, partitionId, taskId) + .format(context.dataFileFormat()) + .operationId(context.queryId()) + .build(); + OutputFileFactory deleteFileFactory = + OutputFileFactory.builderFor(table, partitionId, taskId) + .format(context.deleteFileFormat()) + .operationId(context.queryId()) + .suffix("deletes") + .build(); + + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table) + .dataFileFormat(context.dataFileFormat()) + .dataSchema(context.dataSchema()) + .dataSparkType(context.dataSparkType()) + .deleteFileFormat(context.deleteFileFormat()) + .positionDeleteSparkType(context.deleteSparkType()) + .writeProperties(writeProperties) + .build(); + + if (command == DELETE) { + return new DeleteOnlyDeltaWriter( + table, rewritableDeletes(), writerFactory, deleteFileFactory, context); + + } else if (table.spec().isUnpartitioned()) { + return new UnpartitionedDeltaWriter( + table, rewritableDeletes(), writerFactory, dataFileFactory, deleteFileFactory, context); + + } else { + return new PartitionedDeltaWriter( + table, rewritableDeletes(), writerFactory, dataFileFactory, deleteFileFactory, context); + } + } + + private Map rewritableDeletes() { + return rewritableDeletesBroadcast != null ? rewritableDeletesBroadcast.getValue() : null; + } + } + + private abstract static class BaseDeltaWriter implements DeltaWriter { + + protected InternalRowWrapper initPartitionRowWrapper(Types.StructType partitionType) { + StructType sparkPartitionType = (StructType) SparkSchemaUtil.convert(partitionType); + return new InternalRowWrapper(sparkPartitionType, partitionType); + } + + protected Map buildPartitionProjections( + Types.StructType partitionType, Map specs) { + Map partitionProjections = Maps.newHashMap(); + + for (int specId : specs.keySet()) { + PartitionSpec spec = specs.get(specId); + StructProjection projection = StructProjection.create(partitionType, spec.partitionType()); + partitionProjections.put(specId, projection); + } + + return partitionProjections; + } + + // use a fanout writer only if enabled and the input is unordered and the table is partitioned + protected PartitioningWriter newDataWriter( + Table table, SparkFileWriterFactory writers, OutputFileFactory files, Context context) { + + FileIO io = table.io(); + boolean useFanoutWriter = context.useFanoutWriter(); + long targetFileSize = context.targetDataFileSize(); + + if (table.spec().isPartitioned() && useFanoutWriter) { + return new FanoutDataWriter<>(writers, files, io, targetFileSize); + } else { + return new ClusteredDataWriter<>(writers, files, io, targetFileSize); + } + } + + // the spec requires position deletes to be ordered by file and position + // use a fanout writer if the input is unordered no matter whether fanout writers are enabled + // clustered writers assume that the position deletes are already ordered by file and position + protected PartitioningWriter, DeleteWriteResult> newDeleteWriter( + Table table, + Map rewritableDeletes, + SparkFileWriterFactory writers, + OutputFileFactory files, + 
Context context) { + + FileIO io = table.io(); + boolean inputOrdered = context.inputOrdered(); + long targetFileSize = context.targetDeleteFileSize(); + DeleteGranularity deleteGranularity = context.deleteGranularity(); + + if (inputOrdered && rewritableDeletes == null) { + return new ClusteredPositionDeleteWriter<>( + writers, files, io, targetFileSize, deleteGranularity); + } else { + return new FanoutPositionOnlyDeleteWriter<>( + writers, + files, + io, + targetFileSize, + deleteGranularity, + rewritableDeletes != null + ? new PreviousDeleteLoader(table, rewritableDeletes) + : path -> null /* no previous file scoped deletes */); + } + } + } + + private static class PreviousDeleteLoader implements Function { + private final Map deleteFiles; + private final DeleteLoader deleteLoader; + + PreviousDeleteLoader(Table table, Map deleteFiles) { + this.deleteFiles = deleteFiles; + this.deleteLoader = + new BaseDeleteLoader( + deleteFile -> + EncryptingFileIO.combine(table.io(), table.encryption()) + .newInputFile(deleteFile)); + } + + @Override + public PositionDeleteIndex apply(CharSequence path) { + DeleteFileSet deleteFileSet = deleteFiles.get(path.toString()); + if (deleteFileSet == null) { + return null; + } + + return deleteLoader.loadPositionDeletes(deleteFileSet, path); + } + } + + private static class DeleteOnlyDeltaWriter extends BaseDeltaWriter { + private final PartitioningWriter, DeleteWriteResult> delegate; + private final PositionDelete positionDelete; + private final FileIO io; + private final Map specs; + private final InternalRowWrapper partitionRowWrapper; + private final Map partitionProjections; + private final int specIdOrdinal; + private final int partitionOrdinal; + private final int fileOrdinal; + private final int positionOrdinal; + + private boolean closed = false; + + DeleteOnlyDeltaWriter( + Table table, + Map rewritableDeletes, + SparkFileWriterFactory writerFactory, + OutputFileFactory deleteFileFactory, + Context context) { + + this.delegate = + newDeleteWriter(table, rewritableDeletes, writerFactory, deleteFileFactory, context); + this.positionDelete = PositionDelete.create(); + this.io = table.io(); + this.specs = table.specs(); + + Types.StructType partitionType = Partitioning.partitionType(table); + this.partitionRowWrapper = initPartitionRowWrapper(partitionType); + this.partitionProjections = buildPartitionProjections(partitionType, specs); + + this.specIdOrdinal = context.specIdOrdinal(); + this.partitionOrdinal = context.partitionOrdinal(); + this.fileOrdinal = context.fileOrdinal(); + this.positionOrdinal = context.positionOrdinal(); + } + + @Override + public void delete(InternalRow metadata, InternalRow id) throws IOException { + int specId = metadata.getInt(specIdOrdinal); + PartitionSpec spec = specs.get(specId); + + InternalRow partition = metadata.getStruct(partitionOrdinal, partitionRowWrapper.size()); + StructProjection partitionProjection = partitionProjections.get(specId); + partitionProjection.wrap(partitionRowWrapper.wrap(partition)); + + String file = id.getString(fileOrdinal); + long position = id.getLong(positionOrdinal); + positionDelete.set(file, position, null); + delegate.write(positionDelete, spec, partitionProjection); + } + + @Override + public void update(InternalRow metadata, InternalRow id, InternalRow row) { + throw new UnsupportedOperationException( + this.getClass().getName() + " does not implement update"); + } + + @Override + public void insert(InternalRow row) throws IOException { + throw new UnsupportedOperationException( 
+ this.getClass().getName() + " does not implement insert"); + } + + @Override + public WriterCommitMessage commit() throws IOException { + close(); + + DeleteWriteResult result = delegate.result(); + return new DeltaTaskCommit(result); + } + + @Override + public void abort() throws IOException { + close(); + + DeleteWriteResult result = delegate.result(); + SparkCleanupUtil.deleteTaskFiles(io, result.deleteFiles()); + } + + @Override + public void close() throws IOException { + if (!closed) { + delegate.close(); + this.closed = true; + } + } + } + + @SuppressWarnings("checkstyle:VisibilityModifier") + private abstract static class DeleteAndDataDeltaWriter extends BaseDeltaWriter { + protected final PositionDeltaWriter delegate; + private final FileIO io; + private final Map specs; + private final InternalRowWrapper deletePartitionRowWrapper; + private final Map deletePartitionProjections; + private final int specIdOrdinal; + private final int partitionOrdinal; + private final int fileOrdinal; + private final int positionOrdinal; + + private boolean closed = false; + + DeleteAndDataDeltaWriter( + Table table, + Map rewritableDeletes, + SparkFileWriterFactory writerFactory, + OutputFileFactory dataFileFactory, + OutputFileFactory deleteFileFactory, + Context context) { + this.delegate = + new BasePositionDeltaWriter<>( + newDataWriter(table, writerFactory, dataFileFactory, context), + newDeleteWriter(table, rewritableDeletes, writerFactory, deleteFileFactory, context)); + this.io = table.io(); + this.specs = table.specs(); + + Types.StructType partitionType = Partitioning.partitionType(table); + this.deletePartitionRowWrapper = initPartitionRowWrapper(partitionType); + this.deletePartitionProjections = buildPartitionProjections(partitionType, specs); + + this.specIdOrdinal = context.specIdOrdinal(); + this.partitionOrdinal = context.partitionOrdinal(); + this.fileOrdinal = context.fileOrdinal(); + this.positionOrdinal = context.positionOrdinal(); + } + + @Override + public void delete(InternalRow meta, InternalRow id) throws IOException { + int specId = meta.getInt(specIdOrdinal); + PartitionSpec spec = specs.get(specId); + + InternalRow partition = meta.getStruct(partitionOrdinal, deletePartitionRowWrapper.size()); + StructProjection partitionProjection = deletePartitionProjections.get(specId); + partitionProjection.wrap(deletePartitionRowWrapper.wrap(partition)); + + String file = id.getString(fileOrdinal); + long position = id.getLong(positionOrdinal); + delegate.delete(file, position, spec, partitionProjection); + } + + @Override + public WriterCommitMessage commit() throws IOException { + close(); + + WriteResult result = delegate.result(); + return new DeltaTaskCommit(result); + } + + @Override + public void abort() throws IOException { + close(); + + WriteResult result = delegate.result(); + SparkCleanupUtil.deleteTaskFiles(io, files(result)); + } + + private List> files(WriteResult result) { + List> files = Lists.newArrayList(); + files.addAll(Arrays.asList(result.dataFiles())); + files.addAll(Arrays.asList(result.deleteFiles())); + return files; + } + + @Override + public void close() throws IOException { + if (!closed) { + delegate.close(); + this.closed = true; + } + } + } + + private static class UnpartitionedDeltaWriter extends DeleteAndDataDeltaWriter { + private final PartitionSpec dataSpec; + + UnpartitionedDeltaWriter( + Table table, + Map rewritableDeletes, + SparkFileWriterFactory writerFactory, + OutputFileFactory dataFileFactory, + OutputFileFactory deleteFileFactory, + 
Context context) { + super(table, rewritableDeletes, writerFactory, dataFileFactory, deleteFileFactory, context); + this.dataSpec = table.spec(); + } + + @Override + public void update(InternalRow meta, InternalRow id, InternalRow row) throws IOException { + throw new UnsupportedOperationException("Update must be represented as delete and insert"); + } + + @Override + public void insert(InternalRow row) throws IOException { + delegate.insert(row, dataSpec, null); + } + } + + private static class PartitionedDeltaWriter extends DeleteAndDataDeltaWriter { + private final PartitionSpec dataSpec; + private final PartitionKey dataPartitionKey; + private final InternalRowWrapper internalRowDataWrapper; + + PartitionedDeltaWriter( + Table table, + Map rewritableDeletes, + SparkFileWriterFactory writerFactory, + OutputFileFactory dataFileFactory, + OutputFileFactory deleteFileFactory, + Context context) { + super(table, rewritableDeletes, writerFactory, dataFileFactory, deleteFileFactory, context); + + this.dataSpec = table.spec(); + this.dataPartitionKey = new PartitionKey(dataSpec, context.dataSchema()); + this.internalRowDataWrapper = + new InternalRowWrapper(context.dataSparkType(), context.dataSchema().asStruct()); + } + + @Override + public void update(InternalRow meta, InternalRow id, InternalRow row) throws IOException { + throw new UnsupportedOperationException("Update must be represented as delete and insert"); + } + + @Override + public void insert(InternalRow row) throws IOException { + dataPartitionKey.partition(internalRowDataWrapper.wrap(row)); + delegate.insert(row, dataSpec, dataPartitionKey); + } + } + + // a serializable helper class for common parameters required to configure writers + private static class Context implements Serializable { + private final Schema dataSchema; + private final StructType dataSparkType; + private final FileFormat dataFileFormat; + private final long targetDataFileSize; + private final StructType deleteSparkType; + private final StructType metadataSparkType; + private final FileFormat deleteFileFormat; + private final long targetDeleteFileSize; + private final DeleteGranularity deleteGranularity; + private final String queryId; + private final boolean useFanoutWriter; + private final boolean inputOrdered; + + Context( + Schema dataSchema, + SparkWriteConf writeConf, + LogicalWriteInfo info, + SparkWriteRequirements writeRequirements) { + this.dataSchema = dataSchema; + this.dataSparkType = info.schema(); + this.dataFileFormat = writeConf.dataFileFormat(); + this.targetDataFileSize = writeConf.targetDataFileSize(); + this.deleteSparkType = info.rowIdSchema().get(); + this.deleteFileFormat = writeConf.deleteFileFormat(); + this.targetDeleteFileSize = writeConf.targetDeleteFileSize(); + this.deleteGranularity = writeConf.deleteGranularity(); + this.metadataSparkType = info.metadataSchema().get(); + this.queryId = info.queryId(); + this.useFanoutWriter = writeConf.useFanoutWriter(writeRequirements); + this.inputOrdered = writeRequirements.hasOrdering(); + } + + Schema dataSchema() { + return dataSchema; + } + + StructType dataSparkType() { + return dataSparkType; + } + + FileFormat dataFileFormat() { + return dataFileFormat; + } + + long targetDataFileSize() { + return targetDataFileSize; + } + + StructType deleteSparkType() { + return deleteSparkType; + } + + FileFormat deleteFileFormat() { + return deleteFileFormat; + } + + long targetDeleteFileSize() { + return targetDeleteFileSize; + } + + DeleteGranularity deleteGranularity() { + return 
deleteGranularity; + } + + String queryId() { + return queryId; + } + + boolean useFanoutWriter() { + return useFanoutWriter; + } + + boolean inputOrdered() { + return inputOrdered; + } + + int specIdOrdinal() { + return metadataSparkType.fieldIndex(MetadataColumns.SPEC_ID.name()); + } + + int partitionOrdinal() { + return metadataSparkType.fieldIndex(MetadataColumns.PARTITION_COLUMN_NAME); + } + + int fileOrdinal() { + return deleteSparkType.fieldIndex(MetadataColumns.FILE_PATH.name()); + } + + int positionOrdinal() { + return deleteSparkType.fieldIndex(MetadataColumns.ROW_POSITION.name()); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWriteBuilder.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWriteBuilder.java new file mode 100644 index 000000000000..c58935206bf2 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWriteBuilder.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
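The writer-selection comments above boil down to two small predicates, restated here in isolation. WriterKind is a made-up enum; the real code constructs Clustered*/Fanout* writers using the flags carried by the Context object. Fanout is forced for deletes whenever previously written file-scoped deletes must be merged in, because that input cannot be assumed ordered.

// Editor's illustrative sketch, not part of the patch.
public class WriterChoiceSketch {
  enum WriterKind { CLUSTERED, FANOUT }

  // Position deletes: clustered only when the incoming rows are already ordered and there
  // are no rewritable (previously written) deletes to merge; otherwise fanout.
  static WriterKind deleteWriterKind(boolean inputOrdered, boolean hasRewritableDeletes) {
    return (inputOrdered && !hasRewritableDeletes) ? WriterKind.CLUSTERED : WriterKind.FANOUT;
  }

  // Data files: fanout only when explicitly enabled and the table is partitioned.
  static WriterKind dataWriterKind(boolean partitioned, boolean fanoutEnabled) {
    return (partitioned && fanoutEnabled) ? WriterKind.FANOUT : WriterKind.CLUSTERED;
  }

  public static void main(String[] args) {
    System.out.println(deleteWriterKind(true, false));  // CLUSTERED
    System.out.println(deleteWriterKind(false, false)); // FANOUT
    System.out.println(dataWriterKind(true, true));     // FANOUT
  }
}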
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.types.TypeUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.write.DeltaWrite; +import org.apache.spark.sql.connector.write.DeltaWriteBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.apache.spark.sql.types.StructType; + +class SparkPositionDeltaWriteBuilder implements DeltaWriteBuilder { + + private static final Schema EXPECTED_ROW_ID_SCHEMA = + new Schema(MetadataColumns.FILE_PATH, MetadataColumns.ROW_POSITION); + + private final SparkSession spark; + private final Table table; + private final Command command; + private final SparkBatchQueryScan scan; + private final IsolationLevel isolationLevel; + private final SparkWriteConf writeConf; + private final LogicalWriteInfo info; + private final boolean checkNullability; + private final boolean checkOrdering; + + SparkPositionDeltaWriteBuilder( + SparkSession spark, + Table table, + String branch, + Command command, + Scan scan, + IsolationLevel isolationLevel, + LogicalWriteInfo info) { + this.spark = spark; + this.table = table; + this.command = command; + this.scan = (SparkBatchQueryScan) scan; + this.isolationLevel = isolationLevel; + this.writeConf = new SparkWriteConf(spark, table, branch, info.options()); + this.info = info; + this.checkNullability = writeConf.checkNullability(); + this.checkOrdering = writeConf.checkOrdering(); + } + + @Override + public DeltaWrite build() { + Schema dataSchema = dataSchema(); + + validateRowIdSchema(); + validateMetadataSchema(); + SparkUtil.validatePartitionTransforms(table.spec()); + + return new SparkPositionDeltaWrite( + spark, table, command, scan, isolationLevel, writeConf, info, dataSchema); + } + + private Schema dataSchema() { + if (info.schema() == null || info.schema().isEmpty()) { + return null; + } else { + Schema dataSchema = SparkSchemaUtil.convert(table.schema(), info.schema()); + validateSchema("data", table.schema(), dataSchema); + return dataSchema; + } + } + + private void validateRowIdSchema() { + Preconditions.checkArgument(info.rowIdSchema().isPresent(), "Row ID schema must be set"); + StructType rowIdSparkType = info.rowIdSchema().get(); + Schema rowIdSchema = SparkSchemaUtil.convert(EXPECTED_ROW_ID_SCHEMA, rowIdSparkType); + validateSchema("row ID", EXPECTED_ROW_ID_SCHEMA, rowIdSchema); + } + + private void validateMetadataSchema() { + Preconditions.checkArgument(info.metadataSchema().isPresent(), "Metadata schema must be set"); + Schema expectedMetadataSchema = + new Schema( + MetadataColumns.SPEC_ID, + MetadataColumns.metadataColumn(table, MetadataColumns.PARTITION_COLUMN_NAME)); + StructType metadataSparkType = info.metadataSchema().get(); + Schema metadataSchema = SparkSchemaUtil.convert(expectedMetadataSchema, metadataSparkType); + validateSchema("metadata", expectedMetadataSchema, metadataSchema); + } + + private void validateSchema(String context, Schema expected, Schema actual) { + TypeUtil.validateSchema(context, expected, actual, 
checkNullability, checkOrdering); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowLevelOperationBuilder.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowLevelOperationBuilder.java new file mode 100644 index 000000000000..b113bd9b25af --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowLevelOperationBuilder.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL_DEFAULT; +import static org.apache.iceberg.TableProperties.DELETE_MODE; +import static org.apache.iceberg.TableProperties.DELETE_MODE_DEFAULT; +import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL_DEFAULT; +import static org.apache.iceberg.TableProperties.MERGE_MODE; +import static org.apache.iceberg.TableProperties.MERGE_MODE_DEFAULT; +import static org.apache.iceberg.TableProperties.UPDATE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.UPDATE_ISOLATION_LEVEL_DEFAULT; +import static org.apache.iceberg.TableProperties.UPDATE_MODE; +import static org.apache.iceberg.TableProperties.UPDATE_MODE_DEFAULT; + +import java.util.Map; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Table; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.write.RowLevelOperation; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.apache.spark.sql.connector.write.RowLevelOperationBuilder; +import org.apache.spark.sql.connector.write.RowLevelOperationInfo; + +class SparkRowLevelOperationBuilder implements RowLevelOperationBuilder { + + private final SparkSession spark; + private final Table table; + private final String branch; + private final RowLevelOperationInfo info; + private final RowLevelOperationMode mode; + private final IsolationLevel isolationLevel; + + SparkRowLevelOperationBuilder( + SparkSession spark, Table table, String branch, RowLevelOperationInfo info) { + this.spark = spark; + this.table = table; + this.branch = branch; + this.info = info; + this.mode = mode(table.properties(), info.command()); + this.isolationLevel = isolationLevel(table.properties(), info.command()); + } + + @Override + public RowLevelOperation build() { + switch (mode) { + case COPY_ON_WRITE: + return new SparkCopyOnWriteOperation(spark, table, branch, info, isolationLevel); + case MERGE_ON_READ: + return new SparkPositionDeltaOperation(spark, table, branch, 
info, isolationLevel); + default: + throw new IllegalArgumentException("Unsupported operation mode: " + mode); + } + } + + private RowLevelOperationMode mode(Map properties, Command command) { + String modeName; + + switch (command) { + case DELETE: + modeName = properties.getOrDefault(DELETE_MODE, DELETE_MODE_DEFAULT); + break; + case UPDATE: + modeName = properties.getOrDefault(UPDATE_MODE, UPDATE_MODE_DEFAULT); + break; + case MERGE: + modeName = properties.getOrDefault(MERGE_MODE, MERGE_MODE_DEFAULT); + break; + default: + throw new IllegalArgumentException("Unsupported command: " + command); + } + + return RowLevelOperationMode.fromName(modeName); + } + + private IsolationLevel isolationLevel(Map properties, Command command) { + String levelName; + + switch (command) { + case DELETE: + levelName = properties.getOrDefault(DELETE_ISOLATION_LEVEL, DELETE_ISOLATION_LEVEL_DEFAULT); + break; + case UPDATE: + levelName = properties.getOrDefault(UPDATE_ISOLATION_LEVEL, UPDATE_ISOLATION_LEVEL_DEFAULT); + break; + case MERGE: + levelName = properties.getOrDefault(MERGE_ISOLATION_LEVEL, MERGE_ISOLATION_LEVEL_DEFAULT); + break; + default: + throw new IllegalArgumentException("Unsupported command: " + command); + } + + return IsolationLevel.fromName(levelName); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowReaderFactory.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowReaderFactory.java new file mode 100644 index 000000000000..23699aeb167c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowReaderFactory.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.ChangelogScanTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PositionDeletesScanTask; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +class SparkRowReaderFactory implements PartitionReaderFactory { + + SparkRowReaderFactory() {} + + @Override + public PartitionReader createReader(InputPartition inputPartition) { + Preconditions.checkArgument( + inputPartition instanceof SparkInputPartition, + "Unknown input partition type: %s", + inputPartition.getClass().getName()); + + SparkInputPartition partition = (SparkInputPartition) inputPartition; + + if (partition.allTasksOfType(FileScanTask.class)) { + return new RowDataReader(partition); + + } else if (partition.allTasksOfType(ChangelogScanTask.class)) { + return new ChangelogRowReader(partition); + + } else if (partition.allTasksOfType(PositionDeletesScanTask.class)) { + return new PositionDeletesRowReader(partition); + + } else { + throw new UnsupportedOperationException( + "Unsupported task group for row-based reads: " + partition.taskGroup()); + } + } + + @Override + public PartitionReader createColumnarReader(InputPartition inputPartition) { + throw new UnsupportedOperationException("Columnar reads are not supported"); + } + + @Override + public boolean supportColumnarReads(InputPartition inputPartition) { + return false; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java new file mode 100644 index 000000000000..019f3919dc57 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import org.apache.iceberg.BlobMetadata; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.iceberg.relocated.com.google.common.base.Strings; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.source.metrics.EqualityDeleteFiles; +import org.apache.iceberg.spark.source.metrics.IndexedDeleteFiles; +import org.apache.iceberg.spark.source.metrics.NumDeletes; +import org.apache.iceberg.spark.source.metrics.NumSplits; +import org.apache.iceberg.spark.source.metrics.PositionalDeleteFiles; +import org.apache.iceberg.spark.source.metrics.ResultDataFiles; +import org.apache.iceberg.spark.source.metrics.ResultDeleteFiles; +import org.apache.iceberg.spark.source.metrics.ScannedDataManifests; +import org.apache.iceberg.spark.source.metrics.ScannedDeleteManifests; +import org.apache.iceberg.spark.source.metrics.SkippedDataFiles; +import org.apache.iceberg.spark.source.metrics.SkippedDataManifests; +import org.apache.iceberg.spark.source.metrics.SkippedDeleteFiles; +import org.apache.iceberg.spark.source.metrics.SkippedDeleteManifests; +import org.apache.iceberg.spark.source.metrics.TaskEqualityDeleteFiles; +import org.apache.iceberg.spark.source.metrics.TaskIndexedDeleteFiles; +import org.apache.iceberg.spark.source.metrics.TaskPositionalDeleteFiles; +import org.apache.iceberg.spark.source.metrics.TaskResultDataFiles; +import org.apache.iceberg.spark.source.metrics.TaskResultDeleteFiles; +import org.apache.iceberg.spark.source.metrics.TaskScannedDataManifests; +import org.apache.iceberg.spark.source.metrics.TaskScannedDeleteManifests; +import org.apache.iceberg.spark.source.metrics.TaskSkippedDataFiles; +import org.apache.iceberg.spark.source.metrics.TaskSkippedDataManifests; +import org.apache.iceberg.spark.source.metrics.TaskSkippedDeleteFiles; +import org.apache.iceberg.spark.source.metrics.TaskSkippedDeleteManifests; +import org.apache.iceberg.spark.source.metrics.TaskTotalDataFileSize; +import org.apache.iceberg.spark.source.metrics.TaskTotalDataManifests; +import org.apache.iceberg.spark.source.metrics.TaskTotalDeleteFileSize; +import org.apache.iceberg.spark.source.metrics.TaskTotalDeleteManifests; +import org.apache.iceberg.spark.source.metrics.TaskTotalPlanningDuration; +import org.apache.iceberg.spark.source.metrics.TotalDataFileSize; +import org.apache.iceberg.spark.source.metrics.TotalDataManifests; +import org.apache.iceberg.spark.source.metrics.TotalDeleteFileSize; +import org.apache.iceberg.spark.source.metrics.TotalDeleteManifests; +import org.apache.iceberg.spark.source.metrics.TotalPlanningDuration; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.iceberg.util.TableScanUtil; +import 
org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.metric.CustomMetric; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsReportStatistics; +import org.apache.spark.sql.connector.read.colstats.ColumnStatistics; +import org.apache.spark.sql.connector.read.streaming.MicroBatchStream; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +abstract class SparkScan implements Scan, SupportsReportStatistics { + private static final Logger LOG = LoggerFactory.getLogger(SparkScan.class); + private static final String NDV_KEY = "ndv"; + + private final JavaSparkContext sparkContext; + private final Table table; + private final SparkSession spark; + private final SparkReadConf readConf; + private final boolean caseSensitive; + private final Schema expectedSchema; + private final List<Expression> filterExpressions; + private final String branch; + private final Supplier<ScanReport> scanReportSupplier; + + // lazy variables + private StructType readSchema; + + SparkScan( + SparkSession spark, + Table table, + SparkReadConf readConf, + Schema expectedSchema, + List<Expression> filters, + Supplier<ScanReport> scanReportSupplier) { + Schema snapshotSchema = SnapshotUtil.schemaFor(table, readConf.branch()); + SparkSchemaUtil.validateMetadataColumnReferences(snapshotSchema, expectedSchema); + + this.spark = spark; + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.readConf = readConf; + this.caseSensitive = readConf.caseSensitive(); + this.expectedSchema = expectedSchema; + this.filterExpressions = filters != null ?
filters : Collections.emptyList(); + this.branch = readConf.branch(); + this.scanReportSupplier = scanReportSupplier; + } + + protected Table table() { + return table; + } + + protected String branch() { + return branch; + } + + protected boolean caseSensitive() { + return caseSensitive; + } + + protected Schema expectedSchema() { + return expectedSchema; + } + + protected List<Expression> filterExpressions() { + return filterExpressions; + } + + protected Types.StructType groupingKeyType() { + return Types.StructType.of(); + } + + protected abstract List<ScanTaskGroup<ScanTask>> taskGroups(); + + @Override + public Batch toBatch() { + return new SparkBatch( + sparkContext, table, readConf, groupingKeyType(), taskGroups(), expectedSchema, hashCode()); + } + + @Override + public MicroBatchStream toMicroBatchStream(String checkpointLocation) { + return new SparkMicroBatchStream( + sparkContext, table, readConf, expectedSchema, checkpointLocation); + } + + @Override + public StructType readSchema() { + if (readSchema == null) { + this.readSchema = SparkSchemaUtil.convert(expectedSchema); + } + return readSchema; + } + + @Override + public Statistics estimateStatistics() { + return estimateStatistics(SnapshotUtil.latestSnapshot(table, branch)); + } + + protected Statistics estimateStatistics(Snapshot snapshot) { + // it's a fresh table, no data + if (snapshot == null) { + return new Stats(0L, 0L, Collections.emptyMap()); + } + + boolean cboEnabled = + Boolean.parseBoolean(spark.conf().get(SQLConf.CBO_ENABLED().key(), "false")); + Map<NamedReference, ColumnStatistics> colStatsMap = Collections.emptyMap(); + if (readConf.reportColumnStats() && cboEnabled) { + colStatsMap = Maps.newHashMap(); + List<StatisticsFile> files = table.statisticsFiles(); + if (!files.isEmpty()) { + List<BlobMetadata> metadataList = (files.get(0)).blobMetadata(); + + Map<Integer, List<BlobMetadata>> groupedByField = + metadataList.stream() + .collect( + Collectors.groupingBy( + metadata -> metadata.fields().get(0), Collectors.toList())); + + for (Map.Entry<Integer, List<BlobMetadata>> entry : groupedByField.entrySet()) { + String colName = table.schema().findColumnName(entry.getKey()); + NamedReference ref = FieldReference.column(colName); + Long ndv = null; + + for (BlobMetadata blobMetadata : entry.getValue()) { + if (blobMetadata + .type() + .equals(org.apache.iceberg.puffin.StandardBlobTypes.APACHE_DATASKETCHES_THETA_V1)) { + String ndvStr = blobMetadata.properties().get(NDV_KEY); + if (!Strings.isNullOrEmpty(ndvStr)) { + ndv = Long.parseLong(ndvStr); + } else { + LOG.debug("{} is not set in BlobMetadata for column {}", NDV_KEY, colName); + } + } else { + LOG.debug("Blob type {} is not supported yet", blobMetadata.type()); + } + } + ColumnStatistics colStats = + new SparkColumnStatistics(ndv, null, null, null, null, null, null); + + colStatsMap.put(ref, colStats); + } + } + } + + // estimate stats using snapshot summary only for partitioned tables + // (metadata tables are unpartitioned) + if (!table.spec().isUnpartitioned() && filterExpressions.isEmpty()) { + LOG.debug( + "Using snapshot {} metadata to estimate statistics for table {}", + snapshot.snapshotId(), + table.name()); + long totalRecords = totalRecords(snapshot); + return new Stats( + SparkSchemaUtil.estimateSize(readSchema(), totalRecords), totalRecords, colStatsMap); + } + + long rowsCount = taskGroups().stream().mapToLong(ScanTaskGroup::estimatedRowsCount).sum(); + long sizeInBytes = SparkSchemaUtil.estimateSize(readSchema(), rowsCount); + return new Stats(sizeInBytes, rowsCount, colStatsMap); + } + + private long totalRecords(Snapshot snapshot) { + Map<String, String> summary = snapshot.summary(); + return
PropertyUtil.propertyAsLong(summary, SnapshotSummary.TOTAL_RECORDS_PROP, Long.MAX_VALUE); + } + + @Override + public String description() { + String groupingKeyFieldNamesAsString = + groupingKeyType().fields().stream() + .map(Types.NestedField::name) + .collect(Collectors.joining(", ")); + + return String.format( + "%s (branch=%s) [filters=%s, groupedBy=%s]", + table(), branch(), Spark3Util.describe(filterExpressions), groupingKeyFieldNamesAsString); + } + + @Override + public CustomTaskMetric[] reportDriverMetrics() { + ScanReport scanReport = scanReportSupplier != null ? scanReportSupplier.get() : null; + + if (scanReport == null) { + return new CustomTaskMetric[0]; + } + + List driverMetrics = Lists.newArrayList(); + + // common + driverMetrics.add(TaskTotalPlanningDuration.from(scanReport)); + + // data manifests + driverMetrics.add(TaskTotalDataManifests.from(scanReport)); + driverMetrics.add(TaskScannedDataManifests.from(scanReport)); + driverMetrics.add(TaskSkippedDataManifests.from(scanReport)); + + // data files + driverMetrics.add(TaskResultDataFiles.from(scanReport)); + driverMetrics.add(TaskSkippedDataFiles.from(scanReport)); + driverMetrics.add(TaskTotalDataFileSize.from(scanReport)); + + // delete manifests + driverMetrics.add(TaskTotalDeleteManifests.from(scanReport)); + driverMetrics.add(TaskScannedDeleteManifests.from(scanReport)); + driverMetrics.add(TaskSkippedDeleteManifests.from(scanReport)); + + // delete files + driverMetrics.add(TaskTotalDeleteFileSize.from(scanReport)); + driverMetrics.add(TaskResultDeleteFiles.from(scanReport)); + driverMetrics.add(TaskEqualityDeleteFiles.from(scanReport)); + driverMetrics.add(TaskIndexedDeleteFiles.from(scanReport)); + driverMetrics.add(TaskPositionalDeleteFiles.from(scanReport)); + driverMetrics.add(TaskSkippedDeleteFiles.from(scanReport)); + + return driverMetrics.toArray(new CustomTaskMetric[0]); + } + + @Override + public CustomMetric[] supportedCustomMetrics() { + return new CustomMetric[] { + // task metrics + new NumSplits(), + new NumDeletes(), + + // common + new TotalPlanningDuration(), + + // data manifests + new TotalDataManifests(), + new ScannedDataManifests(), + new SkippedDataManifests(), + + // data files + new ResultDataFiles(), + new SkippedDataFiles(), + new TotalDataFileSize(), + + // delete manifests + new TotalDeleteManifests(), + new ScannedDeleteManifests(), + new SkippedDeleteManifests(), + + // delete files + new TotalDeleteFileSize(), + new ResultDeleteFiles(), + new EqualityDeleteFiles(), + new IndexedDeleteFiles(), + new PositionalDeleteFiles(), + new SkippedDeleteFiles() + }; + } + + protected long adjustSplitSize(List tasks, long splitSize) { + if (readConf.splitSizeOption() == null && readConf.adaptiveSplitSizeEnabled()) { + long scanSize = tasks.stream().mapToLong(ScanTask::sizeBytes).sum(); + int parallelism = readConf.parallelism(); + return TableScanUtil.adjustSplitSize(scanSize, parallelism, splitSize); + } else { + return splitSize; + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java new file mode 100644 index 000000000000..d511fefd8ae0 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java @@ -0,0 +1,761 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.BatchScan; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.IncrementalAppendScan; +import org.apache.iceberg.IncrementalChangelogScan; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.MetricsModes; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SparkDistributedDataScan; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.expressions.AggregateEvaluator; +import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.BoundAggregate; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.ExpressionUtil; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.metrics.InMemoryMetricsReporter; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkAggregates; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkV2Filters; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc; +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsPushDownAggregates; +import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; +import org.apache.spark.sql.connector.read.SupportsPushDownV2Filters; +import org.apache.spark.sql.connector.read.SupportsReportStatistics; +import org.apache.spark.sql.types.StructField; +import 
org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SparkScanBuilder + implements ScanBuilder, + SupportsPushDownAggregates, + SupportsPushDownV2Filters, + SupportsPushDownRequiredColumns, + SupportsReportStatistics { + + private static final Logger LOG = LoggerFactory.getLogger(SparkScanBuilder.class); + private static final Predicate[] NO_PREDICATES = new Predicate[0]; + private StructType pushedAggregateSchema; + private Scan localScan; + + private final SparkSession spark; + private final Table table; + private final CaseInsensitiveStringMap options; + private final SparkReadConf readConf; + private final List<String> metaColumns = Lists.newArrayList(); + private final InMemoryMetricsReporter metricsReporter; + + private Schema schema; + private boolean caseSensitive; + private List<Expression> filterExpressions = null; + private Predicate[] pushedPredicates = NO_PREDICATES; + + SparkScanBuilder( + SparkSession spark, + Table table, + String branch, + Schema schema, + CaseInsensitiveStringMap options) { + this.spark = spark; + this.table = table; + this.schema = schema; + this.options = options; + this.readConf = new SparkReadConf(spark, table, branch, options); + this.caseSensitive = readConf.caseSensitive(); + this.metricsReporter = new InMemoryMetricsReporter(); + } + + SparkScanBuilder(SparkSession spark, Table table, CaseInsensitiveStringMap options) { + this(spark, table, table.schema(), options); + } + + SparkScanBuilder( + SparkSession spark, Table table, String branch, CaseInsensitiveStringMap options) { + this(spark, table, branch, SnapshotUtil.schemaFor(table, branch), options); + } + + SparkScanBuilder( + SparkSession spark, Table table, Schema schema, CaseInsensitiveStringMap options) { + this(spark, table, null, schema, options); + } + + private Expression filterExpression() { + if (filterExpressions != null) { + return filterExpressions.stream().reduce(Expressions.alwaysTrue(), Expressions::and); + } + return Expressions.alwaysTrue(); + } + + public SparkScanBuilder caseSensitive(boolean isCaseSensitive) { + this.caseSensitive = isCaseSensitive; + return this; + } + + @Override + public Predicate[] pushPredicates(Predicate[] predicates) { + // there are 3 kinds of filters: + // (1) filters that can be pushed down completely and don't have to be evaluated by Spark + // (e.g. filters that select entire partitions) + // (2) filters that can be pushed down partially and require record-level filtering in Spark + // (e.g. filters that may select some but not necessarily all rows in a file) + // (3) filters that can't be pushed down at all and have to be evaluated by Spark + // (e.g.
unsupported filters) + // filters (1) and (2) are used to prune files during job planning in Iceberg + // filters (2) and (3) form a set of post scan filters and must be evaluated by Spark + + List<Expression> expressions = Lists.newArrayListWithExpectedSize(predicates.length); + List<Predicate> pushableFilters = Lists.newArrayListWithExpectedSize(predicates.length); + List<Predicate> postScanFilters = Lists.newArrayListWithExpectedSize(predicates.length); + + for (Predicate predicate : predicates) { + try { + Expression expr = SparkV2Filters.convert(predicate); + + if (expr != null) { + // try binding the expression to ensure it can be pushed down + Binder.bind(schema.asStruct(), expr, caseSensitive); + expressions.add(expr); + pushableFilters.add(predicate); + } + + if (expr == null + || unpartitioned() + || !ExpressionUtil.selectsPartitions(expr, table, caseSensitive)) { + postScanFilters.add(predicate); + } else { + LOG.info("Evaluating completely on Iceberg side: {}", predicate); + } + + } catch (Exception e) { + LOG.warn("Failed to check if {} can be pushed down: {}", predicate, e.getMessage()); + postScanFilters.add(predicate); + } + } + + this.filterExpressions = expressions; + this.pushedPredicates = pushableFilters.toArray(new Predicate[0]); + + return postScanFilters.toArray(new Predicate[0]); + } + + private boolean unpartitioned() { + return table.specs().values().stream().noneMatch(PartitionSpec::isPartitioned); + } + + @Override + public Predicate[] pushedPredicates() { + return pushedPredicates; + } + + @Override + public boolean pushAggregation(Aggregation aggregation) { + if (!canPushDownAggregation(aggregation)) { + return false; + } + + AggregateEvaluator aggregateEvaluator; + List<BoundAggregate<?, ?>> expressions = + Lists.newArrayListWithExpectedSize(aggregation.aggregateExpressions().length); + + for (AggregateFunc aggregateFunc : aggregation.aggregateExpressions()) { + try { + Expression expr = SparkAggregates.convert(aggregateFunc); + if (expr != null) { + Expression bound = Binder.bind(schema.asStruct(), expr, caseSensitive); + expressions.add((BoundAggregate<?, ?>) bound); + } else { + LOG.info( + "Skipping aggregate pushdown: AggregateFunc {} can't be converted to iceberg expression", + aggregateFunc); + return false; + } + } catch (IllegalArgumentException e) { + LOG.info("Skipping aggregate pushdown: Bind failed for AggregateFunc {}", aggregateFunc, e); + return false; + } + } + + aggregateEvaluator = AggregateEvaluator.create(expressions); + + if (!metricsModeSupportsAggregatePushDown(aggregateEvaluator.aggregates())) { + return false; + } + + org.apache.iceberg.Scan scan = + buildIcebergBatchScan(true /* include Column Stats */, schemaWithMetadataColumns()); + + try (CloseableIterable<FileScanTask> fileScanTasks = scan.planFiles()) { + for (FileScanTask task : fileScanTasks) { + if (!task.deletes().isEmpty()) { + LOG.info("Skipping aggregate pushdown: detected row level deletes"); + return false; + } + + aggregateEvaluator.update(task.file()); + } + } catch (IOException e) { + LOG.info("Skipping aggregate pushdown: ", e); + return false; + } + + if (!aggregateEvaluator.allAggregatorsValid()) { + return false; + } + + pushedAggregateSchema = + SparkSchemaUtil.convert(new Schema(aggregateEvaluator.resultType().fields())); + InternalRow[] pushedAggregateRows = new InternalRow[1]; + StructLike structLike = aggregateEvaluator.result(); + pushedAggregateRows[0] = + new StructInternalRow(aggregateEvaluator.resultType()).setStruct(structLike); + localScan = + new SparkLocalScan(table, pushedAggregateSchema, pushedAggregateRows,
filterExpressions); + + return true; + } + + private boolean canPushDownAggregation(Aggregation aggregation) { + if (!(table instanceof BaseTable)) { + return false; + } + + if (!readConf.aggregatePushDownEnabled()) { + return false; + } + + // If group by expression is the same as the partition, the statistics information can still + // be used to calculate min/max/count, will enable aggregate push down in next phase. + // TODO: enable aggregate push down for partition col group by expression + if (aggregation.groupByExpressions().length > 0) { + LOG.info("Skipping aggregate pushdown: group by aggregation push down is not supported"); + return false; + } + + return true; + } + + private boolean metricsModeSupportsAggregatePushDown(List> aggregates) { + MetricsConfig config = MetricsConfig.forTable(table); + for (BoundAggregate aggregate : aggregates) { + String colName = aggregate.columnName(); + if (!colName.equals("*")) { + MetricsModes.MetricsMode mode = config.columnMode(colName); + if (mode instanceof MetricsModes.None) { + LOG.info("Skipping aggregate pushdown: No metrics for column {}", colName); + return false; + } else if (mode instanceof MetricsModes.Counts) { + if (aggregate.op() == Expression.Operation.MAX + || aggregate.op() == Expression.Operation.MIN) { + LOG.info( + "Skipping aggregate pushdown: Cannot produce min or max from count for column {}", + colName); + return false; + } + } else if (mode instanceof MetricsModes.Truncate) { + // lower_bounds and upper_bounds may be truncated, so disable push down + if (aggregate.type().typeId() == Type.TypeID.STRING) { + if (aggregate.op() == Expression.Operation.MAX + || aggregate.op() == Expression.Operation.MIN) { + LOG.info( + "Skipping aggregate pushdown: Cannot produce min or max from truncated values for column {}", + colName); + return false; + } + } + } + } + } + + return true; + } + + @Override + public void pruneColumns(StructType requestedSchema) { + StructType requestedProjection = + new StructType( + Stream.of(requestedSchema.fields()) + .filter(field -> MetadataColumns.nonMetadataColumn(field.name())) + .toArray(StructField[]::new)); + + // the projection should include all columns that will be returned, including those only used in + // filters + this.schema = + SparkSchemaUtil.prune(schema, requestedProjection, filterExpression(), caseSensitive); + + Stream.of(requestedSchema.fields()) + .map(StructField::name) + .filter(MetadataColumns::isMetadataColumn) + .distinct() + .forEach(metaColumns::add); + } + + private Schema schemaWithMetadataColumns() { + // metadata columns + List metadataFields = + metaColumns.stream() + .distinct() + .map(name -> MetadataColumns.metadataColumn(table, name)) + .collect(Collectors.toList()); + Schema metadataSchema = calculateMetadataSchema(metadataFields); + + // schema or rows returned by readers + return TypeUtil.join(schema, metadataSchema); + } + + private Schema calculateMetadataSchema(List metaColumnFields) { + Optional partitionField = + metaColumnFields.stream() + .filter(f -> MetadataColumns.PARTITION_COLUMN_ID == f.fieldId()) + .findFirst(); + + // only calculate potential column id collision if partition metadata column was requested + if (!partitionField.isPresent()) { + return new Schema(metaColumnFields); + } + + Set idsToReassign = + TypeUtil.indexById(partitionField.get().type().asStructType()).keySet(); + + // Calculate used ids by union metadata columns with all base table schemas + Set currentlyUsedIds = + 
metaColumnFields.stream().map(Types.NestedField::fieldId).collect(Collectors.toSet()); + Set allUsedIds = + table.schemas().values().stream() + .map(currSchema -> TypeUtil.indexById(currSchema.asStruct()).keySet()) + .reduce(currentlyUsedIds, Sets::union); + + // Reassign selected ids to deduplicate with used ids. + AtomicInteger nextId = new AtomicInteger(); + return new Schema( + metaColumnFields, + table.schema().identifierFieldIds(), + oldId -> { + if (!idsToReassign.contains(oldId)) { + return oldId; + } + int candidate = nextId.incrementAndGet(); + while (allUsedIds.contains(candidate)) { + candidate = nextId.incrementAndGet(); + } + return candidate; + }); + } + + @Override + public Scan build() { + if (localScan != null) { + return localScan; + } else { + return buildBatchScan(); + } + } + + private Scan buildBatchScan() { + Schema expectedSchema = schemaWithMetadataColumns(); + return new SparkBatchQueryScan( + spark, + table, + buildIcebergBatchScan(false /* not include Column Stats */, expectedSchema), + readConf, + expectedSchema, + filterExpressions, + metricsReporter::scanReport); + } + + private org.apache.iceberg.Scan buildIcebergBatchScan(boolean withStats, Schema expectedSchema) { + Long snapshotId = readConf.snapshotId(); + Long asOfTimestamp = readConf.asOfTimestamp(); + String branch = readConf.branch(); + String tag = readConf.tag(); + + Preconditions.checkArgument( + snapshotId == null || asOfTimestamp == null, + "Cannot set both %s and %s to select which table snapshot to scan", + SparkReadOptions.SNAPSHOT_ID, + SparkReadOptions.AS_OF_TIMESTAMP); + + Long startSnapshotId = readConf.startSnapshotId(); + Long endSnapshotId = readConf.endSnapshotId(); + + if (snapshotId != null || asOfTimestamp != null) { + Preconditions.checkArgument( + startSnapshotId == null && endSnapshotId == null, + "Cannot set %s and %s for incremental scans when either %s or %s is set", + SparkReadOptions.START_SNAPSHOT_ID, + SparkReadOptions.END_SNAPSHOT_ID, + SparkReadOptions.SNAPSHOT_ID, + SparkReadOptions.AS_OF_TIMESTAMP); + } + + Preconditions.checkArgument( + startSnapshotId != null || endSnapshotId == null, + "Cannot set only %s for incremental scans. Please, set %s too.", + SparkReadOptions.END_SNAPSHOT_ID, + SparkReadOptions.START_SNAPSHOT_ID); + + Long startTimestamp = readConf.startTimestamp(); + Long endTimestamp = readConf.endTimestamp(); + Preconditions.checkArgument( + startTimestamp == null && endTimestamp == null, + "Cannot set %s or %s for incremental scans and batch scan. 
They are only valid for " + + "changelog scans.", + SparkReadOptions.START_TIMESTAMP, + SparkReadOptions.END_TIMESTAMP); + + if (startSnapshotId != null) { + return buildIncrementalAppendScan(startSnapshotId, endSnapshotId, withStats, expectedSchema); + } else { + return buildBatchScan(snapshotId, asOfTimestamp, branch, tag, withStats, expectedSchema); + } + } + + private org.apache.iceberg.Scan buildBatchScan( + Long snapshotId, + Long asOfTimestamp, + String branch, + String tag, + boolean withStats, + Schema expectedSchema) { + BatchScan scan = + newBatchScan() + .caseSensitive(caseSensitive) + .filter(filterExpression()) + .project(expectedSchema) + .metricsReporter(metricsReporter); + + if (withStats) { + scan = scan.includeColumnStats(); + } + + if (snapshotId != null) { + scan = scan.useSnapshot(snapshotId); + } + + if (asOfTimestamp != null) { + scan = scan.asOfTime(asOfTimestamp); + } + + if (branch != null) { + scan = scan.useRef(branch); + } + + if (tag != null) { + scan = scan.useRef(tag); + } + + return configureSplitPlanning(scan); + } + + private org.apache.iceberg.Scan buildIncrementalAppendScan( + long startSnapshotId, Long endSnapshotId, boolean withStats, Schema expectedSchema) { + IncrementalAppendScan scan = + table + .newIncrementalAppendScan() + .fromSnapshotExclusive(startSnapshotId) + .caseSensitive(caseSensitive) + .filter(filterExpression()) + .project(expectedSchema) + .metricsReporter(metricsReporter); + + if (withStats) { + scan = scan.includeColumnStats(); + } + + if (endSnapshotId != null) { + scan = scan.toSnapshot(endSnapshotId); + } + + return configureSplitPlanning(scan); + } + + @SuppressWarnings("CyclomaticComplexity") + public Scan buildChangelogScan() { + Preconditions.checkArgument( + readConf.snapshotId() == null + && readConf.asOfTimestamp() == null + && readConf.branch() == null + && readConf.tag() == null, + "Cannot set neither %s, %s, %s and %s for changelogs", + SparkReadOptions.SNAPSHOT_ID, + SparkReadOptions.AS_OF_TIMESTAMP, + SparkReadOptions.BRANCH, + SparkReadOptions.TAG); + + Long startSnapshotId = readConf.startSnapshotId(); + Long endSnapshotId = readConf.endSnapshotId(); + Long startTimestamp = readConf.startTimestamp(); + Long endTimestamp = readConf.endTimestamp(); + + Preconditions.checkArgument( + !(startSnapshotId != null && startTimestamp != null), + "Cannot set both %s and %s for changelogs", + SparkReadOptions.START_SNAPSHOT_ID, + SparkReadOptions.START_TIMESTAMP); + + Preconditions.checkArgument( + !(endSnapshotId != null && endTimestamp != null), + "Cannot set both %s and %s for changelogs", + SparkReadOptions.END_SNAPSHOT_ID, + SparkReadOptions.END_TIMESTAMP); + + if (startTimestamp != null && endTimestamp != null) { + Preconditions.checkArgument( + startTimestamp < endTimestamp, + "Cannot set %s to be greater than %s for changelogs", + SparkReadOptions.START_TIMESTAMP, + SparkReadOptions.END_TIMESTAMP); + } + + boolean emptyScan = false; + if (startTimestamp != null) { + startSnapshotId = getStartSnapshotId(startTimestamp); + if (startSnapshotId == null && endTimestamp == null) { + emptyScan = true; + } + } + + if (endTimestamp != null) { + endSnapshotId = getEndSnapshotId(endTimestamp); + if ((startSnapshotId == null && endSnapshotId == null) + || (startSnapshotId != null && startSnapshotId.equals(endSnapshotId))) { + emptyScan = true; + } + } + + Schema expectedSchema = schemaWithMetadataColumns(); + + IncrementalChangelogScan scan = + table + .newIncrementalChangelogScan() + .caseSensitive(caseSensitive) + 
.filter(filterExpression()) + .project(expectedSchema) + .metricsReporter(metricsReporter); + + if (startSnapshotId != null) { + scan = scan.fromSnapshotExclusive(startSnapshotId); + } + + if (endSnapshotId != null) { + scan = scan.toSnapshot(endSnapshotId); + } + + scan = configureSplitPlanning(scan); + + return new SparkChangelogScan( + spark, table, scan, readConf, expectedSchema, filterExpressions, emptyScan); + } + + private Long getStartSnapshotId(Long startTimestamp) { + Snapshot oldestSnapshotAfter = SnapshotUtil.oldestAncestorAfter(table, startTimestamp); + + if (oldestSnapshotAfter == null) { + return null; + } else if (oldestSnapshotAfter.timestampMillis() == startTimestamp) { + return oldestSnapshotAfter.snapshotId(); + } else { + return oldestSnapshotAfter.parentId(); + } + } + + private Long getEndSnapshotId(Long endTimestamp) { + Long endSnapshotId = null; + for (Snapshot snapshot : SnapshotUtil.currentAncestors(table)) { + if (snapshot.timestampMillis() <= endTimestamp) { + endSnapshotId = snapshot.snapshotId(); + break; + } + } + return endSnapshotId; + } + + public Scan buildMergeOnReadScan() { + Preconditions.checkArgument( + readConf.snapshotId() == null && readConf.asOfTimestamp() == null && readConf.tag() == null, + "Cannot set time travel options %s, %s, %s for row-level command scans", + SparkReadOptions.SNAPSHOT_ID, + SparkReadOptions.AS_OF_TIMESTAMP, + SparkReadOptions.TAG); + + Preconditions.checkArgument( + readConf.startSnapshotId() == null && readConf.endSnapshotId() == null, + "Cannot set incremental scan options %s and %s for row-level command scans", + SparkReadOptions.START_SNAPSHOT_ID, + SparkReadOptions.END_SNAPSHOT_ID); + + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, readConf.branch()); + + if (snapshot == null) { + return new SparkBatchQueryScan( + spark, + table, + null, + readConf, + schemaWithMetadataColumns(), + filterExpressions, + metricsReporter::scanReport); + } + + // remember the current snapshot ID for commit validation + long snapshotId = snapshot.snapshotId(); + + CaseInsensitiveStringMap adjustedOptions = + Spark3Util.setOption(SparkReadOptions.SNAPSHOT_ID, Long.toString(snapshotId), options); + SparkReadConf adjustedReadConf = + new SparkReadConf(spark, table, readConf.branch(), adjustedOptions); + + Schema expectedSchema = schemaWithMetadataColumns(); + + BatchScan scan = + newBatchScan() + .useSnapshot(snapshotId) + .caseSensitive(caseSensitive) + .filter(filterExpression()) + .project(expectedSchema) + .metricsReporter(metricsReporter); + + scan = configureSplitPlanning(scan); + + return new SparkBatchQueryScan( + spark, + table, + scan, + adjustedReadConf, + expectedSchema, + filterExpressions, + metricsReporter::scanReport); + } + + public Scan buildCopyOnWriteScan() { + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, readConf.branch()); + + if (snapshot == null) { + return new SparkCopyOnWriteScan( + spark, + table, + readConf, + schemaWithMetadataColumns(), + filterExpressions, + metricsReporter::scanReport); + } + + Schema expectedSchema = schemaWithMetadataColumns(); + + BatchScan scan = + table + .newBatchScan() + .useSnapshot(snapshot.snapshotId()) + .ignoreResiduals() + .caseSensitive(caseSensitive) + .filter(filterExpression()) + .project(expectedSchema) + .metricsReporter(metricsReporter); + + scan = configureSplitPlanning(scan); + + return new SparkCopyOnWriteScan( + spark, + table, + scan, + snapshot, + readConf, + expectedSchema, + filterExpressions, + metricsReporter::scanReport); + } + + private <T extends org.apache.iceberg.Scan<T, ?, ?>> T
configureSplitPlanning(T scan) { + T configuredScan = scan; + + Long splitSize = readConf.splitSizeOption(); + if (splitSize != null) { + configuredScan = configuredScan.option(TableProperties.SPLIT_SIZE, String.valueOf(splitSize)); + } + + Integer splitLookback = readConf.splitLookbackOption(); + if (splitLookback != null) { + configuredScan = + configuredScan.option(TableProperties.SPLIT_LOOKBACK, String.valueOf(splitLookback)); + } + + Long splitOpenFileCost = readConf.splitOpenFileCostOption(); + if (splitOpenFileCost != null) { + configuredScan = + configuredScan.option( + TableProperties.SPLIT_OPEN_FILE_COST, String.valueOf(splitOpenFileCost)); + } + + return configuredScan; + } + + @Override + public Statistics estimateStatistics() { + return ((SupportsReportStatistics) build()).estimateStatistics(); + } + + @Override + public StructType readSchema() { + return build().readSchema(); + } + + private BatchScan newBatchScan() { + if (table instanceof BaseTable && readConf.distributedPlanningEnabled()) { + return new SparkDistributedDataScan(spark, table, readConf); + } else { + return table.newBatchScan(); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java new file mode 100644 index 000000000000..fd299ade7fdc --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Objects; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.spark.sql.SparkSession; + +class SparkStagedScan extends SparkScan { + + private final String taskSetId; + private final long splitSize; + private final int splitLookback; + private final long openFileCost; + + private List<ScanTaskGroup<ScanTask>> taskGroups = null; // lazy cache of tasks + + SparkStagedScan(SparkSession spark, Table table, Schema expectedSchema, SparkReadConf readConf) { + super(spark, table, readConf, expectedSchema, ImmutableList.of(), null); + this.taskSetId = readConf.scanTaskSetId(); + this.splitSize = readConf.splitSize(); + this.splitLookback = readConf.splitLookback(); + this.openFileCost = readConf.splitOpenFileCost(); + } + + @Override + protected List<ScanTaskGroup<ScanTask>> taskGroups() { + if (taskGroups == null) { + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + List<ScanTask> tasks = taskSetManager.fetchTasks(table(), taskSetId); + ValidationException.check( + tasks != null, + "Task set manager has no tasks for table %s with task set ID %s", + table(), + taskSetId); + + this.taskGroups = TableScanUtil.planTaskGroups(tasks, splitSize, splitLookback, openFileCost); + } + return taskGroups; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + + SparkStagedScan that = (SparkStagedScan) other; + return table().name().equals(that.table().name()) + && Objects.equals(taskSetId, that.taskSetId) + && readSchema().equals(that.readSchema()) + && Objects.equals(splitSize, that.splitSize) + && Objects.equals(splitLookback, that.splitLookback) + && Objects.equals(openFileCost, that.openFileCost); + } + + @Override + public int hashCode() { + return Objects.hash( + table().name(), taskSetId, readSchema(), splitSize, splitLookback, openFileCost); + } + + @Override + public String toString() { + return String.format( + "IcebergStagedScan(table=%s, type=%s, taskSetID=%s, caseSensitive=%s)", + table(), expectedSchema().asStruct(), taskSetId, caseSensitive()); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java new file mode 100644 index 000000000000..c5c86c3ebf28 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +class SparkStagedScanBuilder implements ScanBuilder, SupportsPushDownRequiredColumns { + + private final SparkSession spark; + private final Table table; + private final SparkReadConf readConf; + private final List metaColumns = Lists.newArrayList(); + + private Schema schema; + + SparkStagedScanBuilder(SparkSession spark, Table table, CaseInsensitiveStringMap options) { + this.spark = spark; + this.table = table; + this.readConf = new SparkReadConf(spark, table, options); + this.schema = table.schema(); + } + + @Override + public Scan build() { + return new SparkStagedScan(spark, table, schemaWithMetadataColumns(), readConf); + } + + @Override + public void pruneColumns(StructType requestedSchema) { + StructType requestedProjection = removeMetaColumns(requestedSchema); + this.schema = SparkSchemaUtil.prune(schema, requestedProjection); + + Stream.of(requestedSchema.fields()) + .map(StructField::name) + .filter(MetadataColumns::isMetadataColumn) + .distinct() + .forEach(metaColumns::add); + } + + private StructType removeMetaColumns(StructType structType) { + return new StructType( + Stream.of(structType.fields()) + .filter(field -> MetadataColumns.nonMetadataColumn(field.name())) + .toArray(StructField[]::new)); + } + + private Schema schemaWithMetadataColumns() { + // metadata columns + List fields = + metaColumns.stream() + .distinct() + .map(name -> MetadataColumns.metadataColumn(table, name)) + .collect(Collectors.toList()); + Schema meta = new Schema(fields); + + // schema of rows returned by readers + return TypeUtil.join(schema, meta); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java new file mode 100644 index 000000000000..bbc7434138ed --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java @@ -0,0 +1,441 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.TableProperties.CURRENT_SNAPSHOT_ID; +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; + +import java.io.IOException; +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.BaseMetadataTable; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFiles; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.PositionDeletesTable; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Evaluator; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.ExpressionUtil; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.expressions.Projections; +import org.apache.iceberg.expressions.StrictMetricsEvaluator; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.CommitMetadata; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.spark.SparkV2Filters; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.MetadataColumn; +import org.apache.spark.sql.connector.catalog.SupportsDeleteV2; +import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations; +import org.apache.spark.sql.connector.catalog.SupportsWrite; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperationBuilder; +import 
org.apache.spark.sql.connector.write.RowLevelOperationInfo; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SparkTable + implements org.apache.spark.sql.connector.catalog.Table, + SupportsRead, + SupportsWrite, + SupportsDeleteV2, + SupportsRowLevelOperations, + SupportsMetadataColumns { + + private static final Logger LOG = LoggerFactory.getLogger(SparkTable.class); + + private static final Set RESERVED_PROPERTIES = + ImmutableSet.of( + "provider", + "format", + CURRENT_SNAPSHOT_ID, + "location", + FORMAT_VERSION, + "sort-order", + "identifier-fields"); + private static final Set CAPABILITIES = + ImmutableSet.of( + TableCapability.BATCH_READ, + TableCapability.BATCH_WRITE, + TableCapability.MICRO_BATCH_READ, + TableCapability.STREAMING_WRITE, + TableCapability.OVERWRITE_BY_FILTER, + TableCapability.OVERWRITE_DYNAMIC); + private static final Set CAPABILITIES_WITH_ACCEPT_ANY_SCHEMA = + ImmutableSet.builder() + .addAll(CAPABILITIES) + .add(TableCapability.ACCEPT_ANY_SCHEMA) + .build(); + + private final Table icebergTable; + private final Long snapshotId; + private final boolean refreshEagerly; + private final Set capabilities; + private String branch; + private StructType lazyTableSchema = null; + private SparkSession lazySpark = null; + + public SparkTable(Table icebergTable, boolean refreshEagerly) { + this(icebergTable, (Long) null, refreshEagerly); + } + + public SparkTable(Table icebergTable, String branch, boolean refreshEagerly) { + this(icebergTable, refreshEagerly); + this.branch = branch; + ValidationException.check( + branch == null + || SnapshotRef.MAIN_BRANCH.equals(branch) + || icebergTable.snapshot(branch) != null, + "Cannot use branch (does not exist): %s", + branch); + } + + public SparkTable(Table icebergTable, Long snapshotId, boolean refreshEagerly) { + this.icebergTable = icebergTable; + this.snapshotId = snapshotId; + this.refreshEagerly = refreshEagerly; + + boolean acceptAnySchema = + PropertyUtil.propertyAsBoolean( + icebergTable.properties(), + TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA, + TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA_DEFAULT); + this.capabilities = acceptAnySchema ? 
CAPABILITIES_WITH_ACCEPT_ANY_SCHEMA : CAPABILITIES; + } + + private SparkSession sparkSession() { + if (lazySpark == null) { + this.lazySpark = SparkSession.active(); + } + + return lazySpark; + } + + public Table table() { + return icebergTable; + } + + @Override + public String name() { + return icebergTable.toString(); + } + + public Long snapshotId() { + return snapshotId; + } + + public String branch() { + return branch; + } + + public SparkTable copyWithSnapshotId(long newSnapshotId) { + return new SparkTable(icebergTable, newSnapshotId, refreshEagerly); + } + + public SparkTable copyWithBranch(String targetBranch) { + return new SparkTable(icebergTable, targetBranch, refreshEagerly); + } + + private Schema snapshotSchema() { + if (icebergTable instanceof BaseMetadataTable) { + return icebergTable.schema(); + } else if (branch != null) { + return SnapshotUtil.schemaFor(icebergTable, branch); + } else { + return SnapshotUtil.schemaFor(icebergTable, snapshotId, null); + } + } + + @Override + public StructType schema() { + if (lazyTableSchema == null) { + this.lazyTableSchema = SparkSchemaUtil.convert(snapshotSchema()); + } + + return lazyTableSchema; + } + + @Override + public Transform[] partitioning() { + return Spark3Util.toTransforms(icebergTable.spec()); + } + + @Override + public Map properties() { + ImmutableMap.Builder propsBuilder = ImmutableMap.builder(); + + String fileFormat = + icebergTable + .properties() + .getOrDefault( + TableProperties.DEFAULT_FILE_FORMAT, TableProperties.DEFAULT_FILE_FORMAT_DEFAULT); + propsBuilder.put("format", "iceberg/" + fileFormat); + propsBuilder.put("provider", "iceberg"); + String currentSnapshotId = + icebergTable.currentSnapshot() != null + ? String.valueOf(icebergTable.currentSnapshot().snapshotId()) + : "none"; + propsBuilder.put(CURRENT_SNAPSHOT_ID, currentSnapshotId); + propsBuilder.put("location", icebergTable.location()); + + if (icebergTable instanceof BaseTable) { + TableOperations ops = ((BaseTable) icebergTable).operations(); + propsBuilder.put(FORMAT_VERSION, String.valueOf(ops.current().formatVersion())); + } + + if (!icebergTable.sortOrder().isUnsorted()) { + propsBuilder.put("sort-order", Spark3Util.describe(icebergTable.sortOrder())); + } + + Set identifierFields = icebergTable.schema().identifierFieldNames(); + if (!identifierFields.isEmpty()) { + propsBuilder.put("identifier-fields", "[" + String.join(",", identifierFields) + "]"); + } + + icebergTable.properties().entrySet().stream() + .filter(entry -> !RESERVED_PROPERTIES.contains(entry.getKey())) + .forEach(propsBuilder::put); + + return propsBuilder.build(); + } + + @Override + public Set capabilities() { + return capabilities; + } + + @Override + public MetadataColumn[] metadataColumns() { + DataType sparkPartitionType = SparkSchemaUtil.convert(Partitioning.partitionType(table())); + return new MetadataColumn[] { + new SparkMetadataColumn(MetadataColumns.SPEC_ID.name(), DataTypes.IntegerType, false), + new SparkMetadataColumn(MetadataColumns.PARTITION_COLUMN_NAME, sparkPartitionType, true), + new SparkMetadataColumn(MetadataColumns.FILE_PATH.name(), DataTypes.StringType, false), + new SparkMetadataColumn(MetadataColumns.ROW_POSITION.name(), DataTypes.LongType, false), + new SparkMetadataColumn(MetadataColumns.IS_DELETED.name(), DataTypes.BooleanType, false) + }; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + if (options.containsKey(SparkReadOptions.SCAN_TASK_SET_ID)) { + return new SparkStagedScanBuilder(sparkSession(), 
icebergTable, options); + } + + if (refreshEagerly) { + icebergTable.refresh(); + } + + CaseInsensitiveStringMap scanOptions = + branch != null ? options : addSnapshotId(options, snapshotId); + return new SparkScanBuilder( + sparkSession(), icebergTable, branch, snapshotSchema(), scanOptions); + } + + @Override + public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { + Preconditions.checkArgument( + snapshotId == null, "Cannot write to table at a specific snapshot: %s", snapshotId); + + if (icebergTable instanceof PositionDeletesTable) { + return new SparkPositionDeletesRewriteBuilder(sparkSession(), icebergTable, branch, info); + } else { + return new SparkWriteBuilder(sparkSession(), icebergTable, branch, info); + } + } + + @Override + public RowLevelOperationBuilder newRowLevelOperationBuilder(RowLevelOperationInfo info) { + return new SparkRowLevelOperationBuilder(sparkSession(), icebergTable, branch, info); + } + + @Override + public boolean canDeleteWhere(Predicate[] predicates) { + Preconditions.checkArgument( + snapshotId == null, "Cannot delete from table at a specific snapshot: %s", snapshotId); + + Expression deleteExpr = Expressions.alwaysTrue(); + + for (Predicate predicate : predicates) { + Expression expr = SparkV2Filters.convert(predicate); + if (expr != null) { + deleteExpr = Expressions.and(deleteExpr, expr); + } else { + return false; + } + } + + return canDeleteUsingMetadata(deleteExpr); + } + + // a metadata delete is possible iff matching files can be deleted entirely + private boolean canDeleteUsingMetadata(Expression deleteExpr) { + boolean caseSensitive = SparkUtil.caseSensitive(sparkSession()); + + if (ExpressionUtil.selectsPartitions(deleteExpr, table(), caseSensitive)) { + return true; + } + + TableScan scan = + table() + .newScan() + .filter(deleteExpr) + .caseSensitive(caseSensitive) + .includeColumnStats() + .ignoreResiduals(); + + if (branch != null) { + scan = scan.useRef(branch); + } + + try (CloseableIterable tasks = scan.planFiles()) { + Map evaluators = Maps.newHashMap(); + StrictMetricsEvaluator metricsEvaluator = + new StrictMetricsEvaluator(SnapshotUtil.schemaFor(table(), branch), deleteExpr); + + return Iterables.all( + tasks, + task -> { + DataFile file = task.file(); + PartitionSpec spec = task.spec(); + Evaluator evaluator = + evaluators.computeIfAbsent( + spec.specId(), + specId -> + new Evaluator( + spec.partitionType(), Projections.strict(spec).project(deleteExpr))); + return evaluator.eval(file.partition()) || metricsEvaluator.eval(file); + }); + + } catch (IOException ioe) { + LOG.warn("Failed to close task iterable", ioe); + return false; + } + } + + @Override + public void deleteWhere(Predicate[] predicates) { + Expression deleteExpr = SparkV2Filters.convert(predicates); + + if (deleteExpr == Expressions.alwaysFalse()) { + LOG.info("Skipping the delete operation as the condition is always false"); + return; + } + + DeleteFiles deleteFiles = + icebergTable + .newDelete() + .set("spark.app.id", sparkSession().sparkContext().applicationId()) + .deleteFromRowFilter(deleteExpr); + + if (SparkTableUtil.wapEnabled(table())) { + branch = SparkTableUtil.determineWriteBranch(sparkSession(), branch); + } + + if (branch != null) { + deleteFiles.toBranch(branch); + } + + if (!CommitMetadata.commitProperties().isEmpty()) { + CommitMetadata.commitProperties().forEach(deleteFiles::set); + } + + deleteFiles.commit(); + } + + @Override + public String toString() { + return icebergTable.toString(); + } + + @Override + public boolean equals(Object 
other) { + if (this == other) { + return true; + } else if (other == null || getClass() != other.getClass()) { + return false; + } + + // use only name in order to correctly invalidate Spark cache + SparkTable that = (SparkTable) other; + return icebergTable.name().equals(that.icebergTable.name()); + } + + @Override + public int hashCode() { + // use only name in order to correctly invalidate Spark cache + return icebergTable.name().hashCode(); + } + + private static CaseInsensitiveStringMap addSnapshotId( + CaseInsensitiveStringMap options, Long snapshotId) { + if (snapshotId != null) { + String snapshotIdFromOptions = options.get(SparkReadOptions.SNAPSHOT_ID); + String value = snapshotId.toString(); + Preconditions.checkArgument( + snapshotIdFromOptions == null || snapshotIdFromOptions.equals(value), + "Cannot override snapshot ID more than once: %s", + snapshotIdFromOptions); + + Map scanOptions = Maps.newHashMap(); + scanOptions.putAll(options.asCaseSensitiveMap()); + scanOptions.put(SparkReadOptions.SNAPSHOT_ID, value); + scanOptions.remove(SparkReadOptions.AS_OF_TIMESTAMP); + scanOptions.remove(SparkReadOptions.BRANCH); + scanOptions.remove(SparkReadOptions.TAG); + + return new CaseInsensitiveStringMap(scanOptions); + } + + return options; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkView.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkView.java new file mode 100644 index 000000000000..47e57295363d --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkView.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
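A note on the scan-option handling just above: `addSnapshotId` folds a table-level snapshot pin into the read options and removes `as-of-timestamp`, `branch`, and `tag`, so only one time-travel mechanism is ever in effect for a scan. From the user side these are the ordinary Iceberg read options; the sketch below shows the two common forms, with the table name and snapshot id being placeholders.

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class TimeTravelReadSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("iceberg-time-travel-sketch")
        .getOrCreate();

    String table = "demo.db.events";   // hypothetical catalog.db.table
    long snapshotId = 1234567890L;     // hypothetical snapshot id

    // Pins the read to a single snapshot; SparkTable passes this through
    // addSnapshotId(...) so branch/tag/as-of-timestamp cannot also be set.
    Dataset<Row> atSnapshot = spark.read()
        .option("snapshot-id", String.valueOf(snapshotId))
        .table(table);

    // Reads the head of a branch instead; the scan schema follows the branch.
    Dataset<Row> atBranch = spark.read()
        .option("branch", "audit")
        .table(table);

    atSnapshot.show();
    atBranch.show();
  }
}
```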
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; + +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.view.BaseView; +import org.apache.iceberg.view.SQLViewRepresentation; +import org.apache.iceberg.view.View; +import org.apache.iceberg.view.ViewOperations; +import org.apache.spark.sql.types.StructType; + +public class SparkView implements org.apache.spark.sql.connector.catalog.View { + + public static final String QUERY_COLUMN_NAMES = "spark.query-column-names"; + public static final Set RESERVED_PROPERTIES = + ImmutableSet.of("provider", "location", FORMAT_VERSION, QUERY_COLUMN_NAMES); + + private final View icebergView; + private final String catalogName; + private StructType lazySchema = null; + + public SparkView(String catalogName, View icebergView) { + this.catalogName = catalogName; + this.icebergView = icebergView; + } + + public View view() { + return icebergView; + } + + @Override + public String name() { + return icebergView.name(); + } + + @Override + public String query() { + SQLViewRepresentation sqlRepr = icebergView.sqlFor("spark"); + Preconditions.checkState(sqlRepr != null, "Cannot load SQL for view %s", name()); + return sqlRepr.sql(); + } + + @Override + public String currentCatalog() { + return icebergView.currentVersion().defaultCatalog() != null + ? icebergView.currentVersion().defaultCatalog() + : catalogName; + } + + @Override + public String[] currentNamespace() { + return icebergView.currentVersion().defaultNamespace().levels(); + } + + @Override + public StructType schema() { + if (null == lazySchema) { + this.lazySchema = SparkSchemaUtil.convert(icebergView.schema()); + } + + return lazySchema; + } + + @Override + public String[] queryColumnNames() { + return icebergView.properties().containsKey(QUERY_COLUMN_NAMES) + ? 
icebergView.properties().get(QUERY_COLUMN_NAMES).split(",") + : new String[0]; + } + + @Override + public String[] columnAliases() { + return icebergView.schema().columns().stream() + .map(Types.NestedField::name) + .toArray(String[]::new); + } + + @Override + public String[] columnComments() { + return icebergView.schema().columns().stream() + .map(Types.NestedField::doc) + .toArray(String[]::new); + } + + @Override + public Map properties() { + ImmutableMap.Builder propsBuilder = ImmutableMap.builder(); + + propsBuilder.put("provider", "iceberg"); + propsBuilder.put("location", icebergView.location()); + + if (icebergView instanceof BaseView) { + ViewOperations ops = ((BaseView) icebergView).operations(); + propsBuilder.put(FORMAT_VERSION, String.valueOf(ops.current().formatVersion())); + } + + icebergView.properties().entrySet().stream() + .filter(entry -> !RESERVED_PROPERTIES.contains(entry.getKey())) + .forEach(propsBuilder::put); + + return propsBuilder.build(); + } + + @Override + public String toString() { + return icebergView.toString(); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other == null || getClass() != other.getClass()) { + return false; + } + + // use only name in order to correctly invalidate Spark cache + SparkView that = (SparkView) other; + return icebergView.name().equals(that.icebergView.name()); + } + + @Override + public int hashCode() { + // use only name in order to correctly invalidate Spark cache + return icebergView.name().hashCode(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java new file mode 100644 index 000000000000..cc3dc592ecee --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java @@ -0,0 +1,807 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
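With `SparkView` above in place, Iceberg views are surfaced through Spark's V2 view API: `query()` returns the SQL stored for the `spark` dialect, and everything except the reserved properties is exposed as a view property. A rough usage sketch follows; the catalog and object names are placeholders, and it assumes the Spark catalog is backed by an Iceberg catalog with view support (for example a REST catalog), otherwise `CREATE VIEW` falls back to Spark's built-in views.

```java
import org.apache.spark.sql.SparkSession;

public class IcebergViewSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("iceberg-view-sketch")
        .getOrCreate();

    // The stored SQL text is what SparkView.query() later returns via sqlFor("spark").
    spark.sql("CREATE VIEW demo.db.recent_events AS "
        + "SELECT id, data FROM demo.db.events WHERE level = 'INFO'");

    // Resolving the view goes through SparkView: schema(), currentCatalog() and
    // currentNamespace() come from the Iceberg view's current version.
    spark.sql("SELECT * FROM demo.db.recent_events").show();
  }
}
```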
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.IsolationLevel.SERIALIZABLE; +import static org.apache.iceberg.IsolationLevel.SNAPSHOT; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.OverwriteFiles; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.ReplacePartitions; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.SnapshotUpdate; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.CleanableFailure; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.ClusteredDataWriter; +import org.apache.iceberg.io.DataWriteResult; +import org.apache.iceberg.io.FanoutDataWriter; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.FileWriter; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.PartitioningWriter; +import org.apache.iceberg.io.RollingDataWriter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.CommitMetadata; +import org.apache.iceberg.spark.FileRewriteCoordinator; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.spark.SparkWriteRequirements; +import org.apache.iceberg.util.DataFileSet; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskContext$; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.executor.OutputMetrics; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.DataWriter; +import org.apache.spark.sql.connector.write.DataWriterFactory; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.PhysicalWriteInfo; +import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering; +import org.apache.spark.sql.connector.write.Write; +import org.apache.spark.sql.connector.write.WriterCommitMessage; +import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory; +import org.apache.spark.sql.connector.write.streaming.StreamingWrite; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +abstract class SparkWrite implements Write, RequiresDistributionAndOrdering { + private static final Logger LOG = LoggerFactory.getLogger(SparkWrite.class); + + private final JavaSparkContext sparkContext; + private final SparkWriteConf writeConf; + private final Table table; + private final String queryId; + private final FileFormat format; + private final String applicationId; + private final boolean wapEnabled; + private final String wapId; + private final int outputSpecId; + private final String branch; + private final 
long targetFileSize; + private final Schema writeSchema; + private final StructType dsSchema; + private final Map extraSnapshotMetadata; + private final boolean useFanoutWriter; + private final SparkWriteRequirements writeRequirements; + private final Map writeProperties; + + private boolean cleanupOnAbort = false; + + SparkWrite( + SparkSession spark, + Table table, + SparkWriteConf writeConf, + LogicalWriteInfo writeInfo, + String applicationId, + Schema writeSchema, + StructType dsSchema, + SparkWriteRequirements writeRequirements) { + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.writeConf = writeConf; + this.queryId = writeInfo.queryId(); + this.format = writeConf.dataFileFormat(); + this.applicationId = applicationId; + this.wapEnabled = writeConf.wapEnabled(); + this.wapId = writeConf.wapId(); + this.branch = writeConf.branch(); + this.targetFileSize = writeConf.targetDataFileSize(); + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + this.extraSnapshotMetadata = writeConf.extraSnapshotMetadata(); + this.useFanoutWriter = writeConf.useFanoutWriter(writeRequirements); + this.writeRequirements = writeRequirements; + this.outputSpecId = writeConf.outputSpecId(); + this.writeProperties = writeConf.writeProperties(); + } + + @Override + public Distribution requiredDistribution() { + Distribution distribution = writeRequirements.distribution(); + LOG.info("Requesting {} as write distribution for table {}", distribution, table.name()); + return distribution; + } + + @Override + public boolean distributionStrictlyRequired() { + return false; + } + + @Override + public SortOrder[] requiredOrdering() { + SortOrder[] ordering = writeRequirements.ordering(); + LOG.info("Requesting {} as write ordering for table {}", ordering, table.name()); + return ordering; + } + + @Override + public long advisoryPartitionSizeInBytes() { + long size = writeRequirements.advisoryPartitionSize(); + LOG.info("Requesting {} bytes advisory partition size for table {}", size, table.name()); + return size; + } + + BatchWrite asBatchAppend() { + return new BatchAppend(); + } + + BatchWrite asDynamicOverwrite() { + return new DynamicOverwrite(); + } + + BatchWrite asOverwriteByFilter(Expression overwriteExpr) { + return new OverwriteByFilter(overwriteExpr); + } + + BatchWrite asCopyOnWriteOperation(SparkCopyOnWriteScan scan, IsolationLevel isolationLevel) { + return new CopyOnWriteOperation(scan, isolationLevel); + } + + BatchWrite asRewrite(String fileSetID) { + return new RewriteFiles(fileSetID); + } + + StreamingWrite asStreamingAppend() { + return new StreamingAppend(); + } + + StreamingWrite asStreamingOverwrite() { + return new StreamingOverwrite(); + } + + // the writer factory works for both batch and streaming + private WriterFactory createWriterFactory() { + // broadcast the table metadata as the writer factory will be sent to executors + Broadcast
tableBroadcast = + sparkContext.broadcast(SerializableTableWithSize.copyOf(table)); + return new WriterFactory( + tableBroadcast, + queryId, + format, + outputSpecId, + targetFileSize, + writeSchema, + dsSchema, + useFanoutWriter, + writeProperties); + } + + private void commitOperation(SnapshotUpdate operation, String description) { + LOG.info("Committing {} to table {}", description, table); + if (applicationId != null) { + operation.set("spark.app.id", applicationId); + } + + if (!extraSnapshotMetadata.isEmpty()) { + extraSnapshotMetadata.forEach(operation::set); + } + + if (!CommitMetadata.commitProperties().isEmpty()) { + CommitMetadata.commitProperties().forEach(operation::set); + } + + if (wapEnabled && wapId != null) { + // write-audit-publish is enabled for this table and job + // stage the changes without changing the current snapshot + operation.set(SnapshotSummary.STAGED_WAP_ID_PROP, wapId); + operation.stageOnly(); + } + + if (branch != null) { + operation.toBranch(branch); + } + + try { + long start = System.currentTimeMillis(); + operation.commit(); // abort is automatically called if this fails + long duration = System.currentTimeMillis() - start; + LOG.info("Committed in {} ms", duration); + } catch (Exception e) { + cleanupOnAbort = e instanceof CleanableFailure; + throw e; + } + } + + private void abort(WriterCommitMessage[] messages) { + if (cleanupOnAbort) { + SparkCleanupUtil.deleteFiles("job abort", table.io(), files(messages)); + } else { + LOG.warn("Skipping cleanup of written files"); + } + } + + private List files(WriterCommitMessage[] messages) { + List files = Lists.newArrayList(); + + for (WriterCommitMessage message : messages) { + if (message != null) { + TaskCommit taskCommit = (TaskCommit) message; + files.addAll(Arrays.asList(taskCommit.files())); + } + } + + return files; + } + + @Override + public String toString() { + return String.format("IcebergWrite(table=%s, format=%s)", table, format); + } + + private abstract class BaseBatchWrite implements BatchWrite { + @Override + public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { + return createWriterFactory(); + } + + @Override + public boolean useCommitCoordinator() { + return false; + } + + @Override + public void abort(WriterCommitMessage[] messages) { + SparkWrite.this.abort(messages); + } + + @Override + public String toString() { + return String.format("IcebergBatchWrite(table=%s, format=%s)", table, format); + } + } + + private class BatchAppend extends BaseBatchWrite { + @Override + public void commit(WriterCommitMessage[] messages) { + AppendFiles append = table.newAppend(); + + int numFiles = 0; + for (DataFile file : files(messages)) { + numFiles += 1; + append.appendFile(file); + } + + commitOperation(append, String.format("append with %d new data files", numFiles)); + } + } + + private class DynamicOverwrite extends BaseBatchWrite { + @Override + public void commit(WriterCommitMessage[] messages) { + List files = files(messages); + + if (files.isEmpty()) { + LOG.info("Dynamic overwrite is empty, skipping commit"); + return; + } + + ReplacePartitions dynamicOverwrite = table.newReplacePartitions(); + IsolationLevel isolationLevel = writeConf.isolationLevel(); + Long validateFromSnapshotId = writeConf.validateFromSnapshotId(); + + if (isolationLevel != null && validateFromSnapshotId != null) { + dynamicOverwrite.validateFromSnapshot(validateFromSnapshotId); + } + + if (isolationLevel == SERIALIZABLE) { + dynamicOverwrite.validateNoConflictingData(); + 
dynamicOverwrite.validateNoConflictingDeletes(); + + } else if (isolationLevel == SNAPSHOT) { + dynamicOverwrite.validateNoConflictingDeletes(); + } + + int numFiles = 0; + for (DataFile file : files) { + numFiles += 1; + dynamicOverwrite.addFile(file); + } + + commitOperation( + dynamicOverwrite, + String.format("dynamic partition overwrite with %d new data files", numFiles)); + } + } + + private class OverwriteByFilter extends BaseBatchWrite { + private final Expression overwriteExpr; + + private OverwriteByFilter(Expression overwriteExpr) { + this.overwriteExpr = overwriteExpr; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + OverwriteFiles overwriteFiles = table.newOverwrite(); + overwriteFiles.overwriteByRowFilter(overwriteExpr); + + int numFiles = 0; + for (DataFile file : files(messages)) { + numFiles += 1; + overwriteFiles.addFile(file); + } + + IsolationLevel isolationLevel = writeConf.isolationLevel(); + Long validateFromSnapshotId = writeConf.validateFromSnapshotId(); + + if (isolationLevel != null && validateFromSnapshotId != null) { + overwriteFiles.validateFromSnapshot(validateFromSnapshotId); + } + + if (isolationLevel == SERIALIZABLE) { + overwriteFiles.validateNoConflictingDeletes(); + overwriteFiles.validateNoConflictingData(); + + } else if (isolationLevel == SNAPSHOT) { + overwriteFiles.validateNoConflictingDeletes(); + } + + String commitMsg = + String.format("overwrite by filter %s with %d new data files", overwriteExpr, numFiles); + commitOperation(overwriteFiles, commitMsg); + } + } + + private class CopyOnWriteOperation extends BaseBatchWrite { + private final SparkCopyOnWriteScan scan; + private final IsolationLevel isolationLevel; + + private CopyOnWriteOperation(SparkCopyOnWriteScan scan, IsolationLevel isolationLevel) { + this.scan = scan; + this.isolationLevel = isolationLevel; + } + + private List overwrittenFiles() { + if (scan == null) { + return ImmutableList.of(); + } else { + return scan.tasks().stream().map(FileScanTask::file).collect(Collectors.toList()); + } + } + + private Expression conflictDetectionFilter() { + // the list of filter expressions may be empty but is never null + List scanFilterExpressions = scan.filterExpressions(); + + Expression filter = Expressions.alwaysTrue(); + + for (Expression expr : scanFilterExpressions) { + filter = Expressions.and(filter, expr); + } + + return filter; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + OverwriteFiles overwriteFiles = table.newOverwrite(); + + List overwrittenFiles = overwrittenFiles(); + int numOverwrittenFiles = overwrittenFiles.size(); + for (DataFile overwrittenFile : overwrittenFiles) { + overwriteFiles.deleteFile(overwrittenFile); + } + + int numAddedFiles = 0; + for (DataFile file : files(messages)) { + numAddedFiles += 1; + overwriteFiles.addFile(file); + } + + // the scan may be null if the optimizer replaces it with an empty relation (e.g. 
false cond) + // no validation is needed in this case as the command does not depend on the table state + if (scan != null) { + if (isolationLevel == SERIALIZABLE) { + commitWithSerializableIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles); + } else if (isolationLevel == SNAPSHOT) { + commitWithSnapshotIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles); + } else { + throw new IllegalArgumentException("Unsupported isolation level: " + isolationLevel); + } + + } else { + commitOperation( + overwriteFiles, + String.format("overwrite with %d new data files (no validation)", numAddedFiles)); + } + } + + private void commitWithSerializableIsolation( + OverwriteFiles overwriteFiles, int numOverwrittenFiles, int numAddedFiles) { + Long scanSnapshotId = scan.snapshotId(); + if (scanSnapshotId != null) { + overwriteFiles.validateFromSnapshot(scanSnapshotId); + } + + Expression conflictDetectionFilter = conflictDetectionFilter(); + overwriteFiles.conflictDetectionFilter(conflictDetectionFilter); + overwriteFiles.validateNoConflictingData(); + overwriteFiles.validateNoConflictingDeletes(); + + String commitMsg = + String.format( + "overwrite of %d data files with %d new data files, scanSnapshotId: %d, conflictDetectionFilter: %s", + numOverwrittenFiles, numAddedFiles, scanSnapshotId, conflictDetectionFilter); + commitOperation(overwriteFiles, commitMsg); + } + + private void commitWithSnapshotIsolation( + OverwriteFiles overwriteFiles, int numOverwrittenFiles, int numAddedFiles) { + Long scanSnapshotId = scan.snapshotId(); + if (scanSnapshotId != null) { + overwriteFiles.validateFromSnapshot(scanSnapshotId); + } + + Expression conflictDetectionFilter = conflictDetectionFilter(); + overwriteFiles.conflictDetectionFilter(conflictDetectionFilter); + overwriteFiles.validateNoConflictingDeletes(); + + String commitMsg = + String.format( + "overwrite of %d data files with %d new data files", + numOverwrittenFiles, numAddedFiles); + commitOperation(overwriteFiles, commitMsg); + } + } + + private class RewriteFiles extends BaseBatchWrite { + private final String fileSetID; + + private RewriteFiles(String fileSetID) { + this.fileSetID = fileSetID; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + FileRewriteCoordinator coordinator = FileRewriteCoordinator.get(); + coordinator.stageRewrite(table, fileSetID, DataFileSet.of(files(messages))); + } + } + + private abstract class BaseStreamingWrite implements StreamingWrite { + private static final String QUERY_ID_PROPERTY = "spark.sql.streaming.queryId"; + private static final String EPOCH_ID_PROPERTY = "spark.sql.streaming.epochId"; + + protected abstract String mode(); + + @Override + public StreamingDataWriterFactory createStreamingWriterFactory(PhysicalWriteInfo info) { + return createWriterFactory(); + } + + @Override + public boolean useCommitCoordinator() { + return false; + } + + @Override + public final void commit(long epochId, WriterCommitMessage[] messages) { + LOG.info("Committing epoch {} for query {} in {} mode", epochId, queryId, mode()); + + table.refresh(); + + Long lastCommittedEpochId = findLastCommittedEpochId(); + if (lastCommittedEpochId != null && epochId <= lastCommittedEpochId) { + LOG.info("Skipping epoch {} for query {} as it was already committed", epochId, queryId); + return; + } + + doCommit(epochId, messages); + } + + protected abstract void doCommit(long epochId, WriterCommitMessage[] messages); + + protected void commit(SnapshotUpdate snapshotUpdate, long epochId, String description) 
{ + snapshotUpdate.set(QUERY_ID_PROPERTY, queryId); + snapshotUpdate.set(EPOCH_ID_PROPERTY, Long.toString(epochId)); + commitOperation(snapshotUpdate, description); + } + + private Long findLastCommittedEpochId() { + Snapshot snapshot = table.currentSnapshot(); + Long lastCommittedEpochId = null; + while (snapshot != null) { + Map summary = snapshot.summary(); + String snapshotQueryId = summary.get(QUERY_ID_PROPERTY); + if (queryId.equals(snapshotQueryId)) { + lastCommittedEpochId = Long.valueOf(summary.get(EPOCH_ID_PROPERTY)); + break; + } + Long parentSnapshotId = snapshot.parentId(); + snapshot = parentSnapshotId != null ? table.snapshot(parentSnapshotId) : null; + } + return lastCommittedEpochId; + } + + @Override + public void abort(long epochId, WriterCommitMessage[] messages) { + SparkWrite.this.abort(messages); + } + + @Override + public String toString() { + return String.format("IcebergStreamingWrite(table=%s, format=%s)", table, format); + } + } + + private class StreamingAppend extends BaseStreamingWrite { + @Override + protected String mode() { + return "append"; + } + + @Override + protected void doCommit(long epochId, WriterCommitMessage[] messages) { + AppendFiles append = table.newFastAppend(); + int numFiles = 0; + for (DataFile file : files(messages)) { + append.appendFile(file); + numFiles++; + } + commit(append, epochId, String.format("streaming append with %d new data files", numFiles)); + } + } + + private class StreamingOverwrite extends BaseStreamingWrite { + @Override + protected String mode() { + return "complete"; + } + + @Override + public void doCommit(long epochId, WriterCommitMessage[] messages) { + OverwriteFiles overwriteFiles = table.newOverwrite(); + overwriteFiles.overwriteByRowFilter(Expressions.alwaysTrue()); + int numFiles = 0; + for (DataFile file : files(messages)) { + overwriteFiles.addFile(file); + numFiles++; + } + commit( + overwriteFiles, + epochId, + String.format("streaming complete overwrite with %d new data files", numFiles)); + } + } + + public static class TaskCommit implements WriterCommitMessage { + private final DataFile[] taskFiles; + + TaskCommit(DataFile[] taskFiles) { + this.taskFiles = taskFiles; + } + + // Reports bytesWritten and recordsWritten to the Spark output metrics. + // Can only be called in executor. + void reportOutputMetrics() { + long bytesWritten = 0L; + long recordsWritten = 0L; + for (DataFile dataFile : taskFiles) { + bytesWritten += dataFile.fileSizeInBytes(); + recordsWritten += dataFile.recordCount(); + } + + TaskContext taskContext = TaskContext$.MODULE$.get(); + if (taskContext != null) { + OutputMetrics outputMetrics = taskContext.taskMetrics().outputMetrics(); + outputMetrics.setBytesWritten(bytesWritten); + outputMetrics.setRecordsWritten(recordsWritten); + } + } + + DataFile[] files() { + return taskFiles; + } + } + + private static class WriterFactory implements DataWriterFactory, StreamingDataWriterFactory { + private final Broadcast
tableBroadcast; + private final FileFormat format; + private final int outputSpecId; + private final long targetFileSize; + private final Schema writeSchema; + private final StructType dsSchema; + private final boolean useFanoutWriter; + private final String queryId; + private final Map writeProperties; + + protected WriterFactory( + Broadcast
tableBroadcast, + String queryId, + FileFormat format, + int outputSpecId, + long targetFileSize, + Schema writeSchema, + StructType dsSchema, + boolean useFanoutWriter, + Map writeProperties) { + this.tableBroadcast = tableBroadcast; + this.format = format; + this.outputSpecId = outputSpecId; + this.targetFileSize = targetFileSize; + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + this.useFanoutWriter = useFanoutWriter; + this.queryId = queryId; + this.writeProperties = writeProperties; + } + + @Override + public DataWriter createWriter(int partitionId, long taskId) { + return createWriter(partitionId, taskId, 0); + } + + @Override + public DataWriter createWriter(int partitionId, long taskId, long epochId) { + Table table = tableBroadcast.value(); + PartitionSpec spec = table.specs().get(outputSpecId); + FileIO io = table.io(); + String operationId = queryId + "-" + epochId; + OutputFileFactory fileFactory = + OutputFileFactory.builderFor(table, partitionId, taskId) + .format(format) + .operationId(operationId) + .build(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table) + .dataFileFormat(format) + .dataSchema(writeSchema) + .dataSparkType(dsSchema) + .writeProperties(writeProperties) + .build(); + + if (spec.isUnpartitioned()) { + return new UnpartitionedDataWriter(writerFactory, fileFactory, io, spec, targetFileSize); + + } else { + return new PartitionedDataWriter( + writerFactory, + fileFactory, + io, + spec, + writeSchema, + dsSchema, + targetFileSize, + useFanoutWriter); + } + } + } + + private static class UnpartitionedDataWriter implements DataWriter { + private final FileWriter delegate; + private final FileIO io; + + private UnpartitionedDataWriter( + SparkFileWriterFactory writerFactory, + OutputFileFactory fileFactory, + FileIO io, + PartitionSpec spec, + long targetFileSize) { + this.delegate = + new RollingDataWriter<>(writerFactory, fileFactory, io, targetFileSize, spec, null); + this.io = io; + } + + @Override + public void write(InternalRow record) throws IOException { + delegate.write(record); + } + + @Override + public WriterCommitMessage commit() throws IOException { + close(); + + DataWriteResult result = delegate.result(); + TaskCommit taskCommit = new TaskCommit(result.dataFiles().toArray(new DataFile[0])); + taskCommit.reportOutputMetrics(); + return taskCommit; + } + + @Override + public void abort() throws IOException { + close(); + + DataWriteResult result = delegate.result(); + SparkCleanupUtil.deleteTaskFiles(io, result.dataFiles()); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + } + + private static class PartitionedDataWriter implements DataWriter { + private final PartitioningWriter delegate; + private final FileIO io; + private final PartitionSpec spec; + private final PartitionKey partitionKey; + private final InternalRowWrapper internalRowWrapper; + + private PartitionedDataWriter( + SparkFileWriterFactory writerFactory, + OutputFileFactory fileFactory, + FileIO io, + PartitionSpec spec, + Schema dataSchema, + StructType dataSparkType, + long targetFileSize, + boolean fanoutEnabled) { + if (fanoutEnabled) { + this.delegate = new FanoutDataWriter<>(writerFactory, fileFactory, io, targetFileSize); + } else { + this.delegate = new ClusteredDataWriter<>(writerFactory, fileFactory, io, targetFileSize); + } + this.io = io; + this.spec = spec; + this.partitionKey = new PartitionKey(spec, dataSchema); + this.internalRowWrapper = new InternalRowWrapper(dataSparkType, 
dataSchema.asStruct()); + } + + @Override + public void write(InternalRow row) throws IOException { + partitionKey.partition(internalRowWrapper.wrap(row)); + delegate.write(row, spec, partitionKey); + } + + @Override + public WriterCommitMessage commit() throws IOException { + close(); + + DataWriteResult result = delegate.result(); + TaskCommit taskCommit = new TaskCommit(result.dataFiles().toArray(new DataFile[0])); + taskCommit.reportOutputMetrics(); + return taskCommit; + } + + @Override + public void abort() throws IOException { + close(); + + DataWriteResult result = delegate.result(); + SparkCleanupUtil.deleteTaskFiles(io, result.dataFiles()); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java new file mode 100644 index 000000000000..602b692d7352 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
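`SparkWrite` above carries all of the batch commit paths; which one runs is decided later by `SparkWriteBuilder`, but the user-facing triggers are the ordinary DataFrameWriterV2 calls. The sketch below shows the rough mapping, using a hypothetical table name and a query-generated data set; it is illustrative only and assumes the target Iceberg table already exists.

```java
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.lit;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class WriteModesSketch {
  public static void main(String[] args) throws Exception {
    SparkSession spark = SparkSession.builder()
        .appName("iceberg-write-modes-sketch")
        .getOrCreate();

    // Hypothetical data set and table name.
    Dataset<Row> df = spark.sql("SELECT id, '2024-01-01' AS day FROM range(1000)");

    // Plain append -> SparkWrite.asBatchAppend()
    df.writeTo("demo.db.events").append();

    // Dynamic partition overwrite -> SparkWrite.asDynamicOverwrite()
    df.writeTo("demo.db.events").overwritePartitions();

    // Overwrite by filter -> SparkWrite.asOverwriteByFilter(expr)
    df.writeTo("demo.db.events").overwrite(col("day").equalTo(lit("2024-01-01")));
  }
}
```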
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.UpdateSchema; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.SparkFilters; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.spark.SparkWriteRequirements; +import org.apache.iceberg.types.TypeUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.apache.spark.sql.connector.write.SupportsDynamicOverwrite; +import org.apache.spark.sql.connector.write.SupportsOverwrite; +import org.apache.spark.sql.connector.write.Write; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.connector.write.streaming.StreamingWrite; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructType; + +class SparkWriteBuilder implements WriteBuilder, SupportsDynamicOverwrite, SupportsOverwrite { + private final SparkSession spark; + private final Table table; + private final SparkWriteConf writeConf; + private final LogicalWriteInfo writeInfo; + private final StructType dsSchema; + private final String overwriteMode; + private final String rewrittenFileSetId; + private boolean overwriteDynamic = false; + private boolean overwriteByFilter = false; + private Expression overwriteExpr = null; + private boolean overwriteFiles = false; + private SparkCopyOnWriteScan copyOnWriteScan = null; + private Command copyOnWriteCommand = null; + private IsolationLevel copyOnWriteIsolationLevel = null; + + SparkWriteBuilder(SparkSession spark, Table table, String branch, LogicalWriteInfo info) { + this.spark = spark; + this.table = table; + this.writeConf = new SparkWriteConf(spark, table, branch, info.options()); + this.writeInfo = info; + this.dsSchema = info.schema(); + this.overwriteMode = writeConf.overwriteMode(); + this.rewrittenFileSetId = writeConf.rewrittenFileSetId(); + } + + public WriteBuilder overwriteFiles(Scan scan, Command command, IsolationLevel isolationLevel) { + Preconditions.checkState(!overwriteByFilter, "Cannot overwrite individual files and by filter"); + Preconditions.checkState( + !overwriteDynamic, "Cannot overwrite individual files and dynamically"); + Preconditions.checkState( + rewrittenFileSetId == null, "Cannot overwrite individual files and rewrite"); + + this.overwriteFiles = true; + this.copyOnWriteScan = (SparkCopyOnWriteScan) scan; + this.copyOnWriteCommand = command; + this.copyOnWriteIsolationLevel = isolationLevel; + return this; + } + + @Override + public WriteBuilder overwriteDynamicPartitions() { + Preconditions.checkState( + !overwriteByFilter, "Cannot overwrite dynamically and by filter: %s", overwriteExpr); + Preconditions.checkState(!overwriteFiles, "Cannot overwrite individual files and dynamically"); + Preconditions.checkState( + rewrittenFileSetId == null, "Cannot overwrite dynamically and rewrite"); + + this.overwriteDynamic = true; + return this; + } + + @Override + public WriteBuilder overwrite(Filter[] 
filters) { + Preconditions.checkState( + !overwriteFiles, "Cannot overwrite individual files and using filters"); + Preconditions.checkState(rewrittenFileSetId == null, "Cannot overwrite and rewrite"); + + this.overwriteExpr = SparkFilters.convert(filters); + if (overwriteExpr == Expressions.alwaysTrue() && "dynamic".equals(overwriteMode)) { + // use the write option to override truncating the table. use dynamic overwrite instead. + this.overwriteDynamic = true; + } else { + Preconditions.checkState( + !overwriteDynamic, "Cannot overwrite dynamically and by filter: %s", overwriteExpr); + this.overwriteByFilter = true; + } + return this; + } + + @Override + public Write build() { + // Validate + Schema writeSchema = validateOrMergeWriteSchema(table, dsSchema, writeConf); + SparkUtil.validatePartitionTransforms(table.spec()); + + // Get application id + String appId = spark.sparkContext().applicationId(); + + return new SparkWrite( + spark, table, writeConf, writeInfo, appId, writeSchema, dsSchema, writeRequirements()) { + + @Override + public BatchWrite toBatch() { + if (rewrittenFileSetId != null) { + return asRewrite(rewrittenFileSetId); + } else if (overwriteByFilter) { + return asOverwriteByFilter(overwriteExpr); + } else if (overwriteDynamic) { + return asDynamicOverwrite(); + } else if (overwriteFiles) { + return asCopyOnWriteOperation(copyOnWriteScan, copyOnWriteIsolationLevel); + } else { + return asBatchAppend(); + } + } + + @Override + public StreamingWrite toStreaming() { + Preconditions.checkState( + !overwriteDynamic, "Unsupported streaming operation: dynamic partition overwrite"); + Preconditions.checkState( + !overwriteByFilter || overwriteExpr == Expressions.alwaysTrue(), + "Unsupported streaming operation: overwrite by filter: %s", + overwriteExpr); + Preconditions.checkState( + rewrittenFileSetId == null, "Unsupported streaming operation: rewrite"); + + if (overwriteByFilter) { + return asStreamingOverwrite(); + } else { + return asStreamingAppend(); + } + } + }; + } + + private SparkWriteRequirements writeRequirements() { + if (overwriteFiles) { + return writeConf.copyOnWriteRequirements(copyOnWriteCommand); + } else { + return writeConf.writeRequirements(); + } + } + + private static Schema validateOrMergeWriteSchema( + Table table, StructType dsSchema, SparkWriteConf writeConf) { + Schema writeSchema; + boolean caseSensitive = writeConf.caseSensitive(); + if (writeConf.mergeSchema()) { + // convert the dataset schema and assign fresh ids for new fields + Schema newSchema = + SparkSchemaUtil.convertWithFreshIds(table.schema(), dsSchema, caseSensitive); + + // update the table to get final id assignments and validate the changes + UpdateSchema update = + table.updateSchema().caseSensitive(caseSensitive).unionByNameWith(newSchema); + Schema mergedSchema = update.apply(); + + // reconvert the dsSchema without assignment to use the ids assigned by UpdateSchema + writeSchema = SparkSchemaUtil.convert(mergedSchema, dsSchema, caseSensitive); + + TypeUtil.validateWriteSchema( + mergedSchema, writeSchema, writeConf.checkNullability(), writeConf.checkOrdering()); + + // if the validation passed, update the table schema + update.commit(); + } else { + writeSchema = SparkSchemaUtil.convert(table.schema(), dsSchema, caseSensitive); + TypeUtil.validateWriteSchema( + table.schema(), writeSchema, writeConf.checkNullability(), writeConf.checkOrdering()); + } + + return writeSchema; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/StagedSparkTable.java 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/StagedSparkTable.java new file mode 100644 index 000000000000..b92c02d2b536 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/StagedSparkTable.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.Transaction; +import org.apache.spark.sql.connector.catalog.StagedTable; + +public class StagedSparkTable extends SparkTable implements StagedTable { + private final Transaction transaction; + + public StagedSparkTable(Transaction transaction) { + super(transaction.table(), false); + this.transaction = transaction; + } + + @Override + public void commitStagedChanges() { + transaction.commitTransaction(); + } + + @Override + public void abortStagedChanges() { + // TODO: clean up + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/Stats.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/Stats.java new file mode 100644 index 000000000000..ccf523cb4b05 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/Stats.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
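The small `StagedSparkTable` above is what makes CTAS and RTAS atomic: the catalog stages the new table as an Iceberg `Transaction`, Spark writes the query result through the staged table, and only `commitStagedChanges()` publishes it, while a failed query ends in `abortStagedChanges()` instead. A minimal usage sketch follows, with hypothetical names and assuming `demo` is configured as an Iceberg Spark catalog that supports staged creation.

```java
import org.apache.spark.sql.SparkSession;

public class CtasSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("iceberg-ctas-sketch")
        .getOrCreate();

    // Atomic CTAS: Spark asks the staging catalog for a StagedSparkTable, writes the
    // query result through it, and the table only becomes visible once the wrapped
    // Transaction is committed in commitStagedChanges().
    spark.sql("CREATE TABLE demo.db.events_copy USING iceberg "
        + "AS SELECT * FROM demo.db.events");
  }
}
```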
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import java.util.OptionalLong; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.colstats.ColumnStatistics; + +class Stats implements Statistics { + private final OptionalLong sizeInBytes; + private final OptionalLong numRows; + private final Map colstats; + + Stats(long sizeInBytes, long numRows, Map colstats) { + this.sizeInBytes = OptionalLong.of(sizeInBytes); + this.numRows = OptionalLong.of(numRows); + this.colstats = colstats; + } + + @Override + public OptionalLong sizeInBytes() { + return sizeInBytes; + } + + @Override + public OptionalLong numRows() { + return numRows; + } + + @Override + public Map columnStats() { + return colstats; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/StreamingOffset.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/StreamingOffset.java new file mode 100644 index 000000000000..f2088deb1ee3 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/StreamingOffset.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonNode; +import java.io.IOException; +import java.io.InputStream; +import java.io.StringWriter; +import java.io.UncheckedIOException; +import org.apache.iceberg.relocated.com.google.common.base.Objects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.JsonUtil; +import org.apache.spark.sql.connector.read.streaming.Offset; + +class StreamingOffset extends Offset { + static final StreamingOffset START_OFFSET = new StreamingOffset(-1L, -1, false); + + private static final int CURR_VERSION = 1; + private static final String VERSION = "version"; + private static final String SNAPSHOT_ID = "snapshot_id"; + private static final String POSITION = "position"; + private static final String SCAN_ALL_FILES = "scan_all_files"; + + private final long snapshotId; + private final long position; + private final boolean scanAllFiles; + + /** + * An implementation of Spark Structured Streaming Offset, to track the current processed files of + * Iceberg table. + * + * @param snapshotId The current processed snapshot id. + * @param position The position of last scanned file in snapshot. + * @param scanAllFiles whether to scan all files in a snapshot; for example, to read all data when + * starting a stream. 
+ */ + StreamingOffset(long snapshotId, long position, boolean scanAllFiles) { + this.snapshotId = snapshotId; + this.position = position; + this.scanAllFiles = scanAllFiles; + } + + static StreamingOffset fromJson(String json) { + Preconditions.checkNotNull(json, "Cannot parse StreamingOffset JSON: null"); + + try { + JsonNode node = JsonUtil.mapper().readValue(json, JsonNode.class); + return fromJsonNode(node); + } catch (IOException e) { + throw new UncheckedIOException( + String.format("Failed to parse StreamingOffset from JSON string %s", json), e); + } + } + + static StreamingOffset fromJson(InputStream inputStream) { + Preconditions.checkNotNull(inputStream, "Cannot parse StreamingOffset from inputStream: null"); + + JsonNode node; + try { + node = JsonUtil.mapper().readValue(inputStream, JsonNode.class); + } catch (IOException e) { + throw new UncheckedIOException("Failed to read StreamingOffset from json", e); + } + + return fromJsonNode(node); + } + + @Override + public String json() { + StringWriter writer = new StringWriter(); + try { + JsonGenerator generator = JsonUtil.factory().createGenerator(writer); + generator.writeStartObject(); + generator.writeNumberField(VERSION, CURR_VERSION); + generator.writeNumberField(SNAPSHOT_ID, snapshotId); + generator.writeNumberField(POSITION, position); + generator.writeBooleanField(SCAN_ALL_FILES, scanAllFiles); + generator.writeEndObject(); + generator.flush(); + + } catch (IOException e) { + throw new UncheckedIOException("Failed to write StreamingOffset to json", e); + } + + return writer.toString(); + } + + long snapshotId() { + return snapshotId; + } + + long position() { + return position; + } + + boolean shouldScanAllFiles() { + return scanAllFiles; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof StreamingOffset) { + StreamingOffset offset = (StreamingOffset) obj; + return offset.snapshotId == snapshotId + && offset.position == position + && offset.scanAllFiles == scanAllFiles; + } else { + return false; + } + } + + @Override + public int hashCode() { + return Objects.hashCode(snapshotId, position, scanAllFiles); + } + + @Override + public String toString() { + return String.format( + "Streaming Offset[%d: position (%d) scan_all_files (%b)]", + snapshotId, position, scanAllFiles); + } + + private static StreamingOffset fromJsonNode(JsonNode node) { + // The version of StreamingOffset. The offset was created with a version number + // used to validate when deserializing from json string. + int version = JsonUtil.getInt(VERSION, node); + Preconditions.checkArgument( + version == CURR_VERSION, + "This version of Iceberg source only supports version %s. Version %s is not supported.", + CURR_VERSION, + version); + + long snapshotId = JsonUtil.getLong(SNAPSHOT_ID, node); + int position = JsonUtil.getInt(POSITION, node); + boolean shouldScanAllFiles = JsonUtil.getBool(SCAN_ALL_FILES, node); + + return new StreamingOffset(snapshotId, position, shouldScanAllFiles); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/StructInternalRow.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/StructInternalRow.java new file mode 100644 index 000000000000..25a349b891bd --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/StructInternalRow.java @@ -0,0 +1,385 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
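`StreamingOffset` above is the checkpointable position for the micro-batch source, serialized as a small versioned JSON object. The sketch below shows the wire format produced by `json()` and the version guard applied in `fromJsonNode(...)`, using plain Jackson rather than Iceberg's internal `JsonUtil`; the snapshot id and position values are made up.

```java
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

public class OffsetJsonSketch {
  public static void main(String[] args) throws Exception {
    // Shape written by StreamingOffset.json(): a flat object with a version field.
    String json =
        "{\"version\":1,\"snapshot_id\":8562873519416387120,\"position\":42,\"scan_all_files\":false}";

    ObjectMapper mapper = new ObjectMapper();
    JsonNode node = mapper.readTree(json);

    // Mirrors fromJsonNode(...): refuse anything but version 1 before reading the fields.
    if (node.get("version").asInt() != 1) {
      throw new IllegalArgumentException("Unsupported offset version: " + node.get("version"));
    }

    long snapshotId = node.get("snapshot_id").asLong();
    long position = node.get("position").asLong();
    boolean scanAllFiles = node.get("scan_all_files").asBoolean();

    System.out.printf("snapshot=%d position=%d scanAllFiles=%b%n", snapshotId, position, scanAllFiles);
  }
}
```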
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDate; +import java.time.OffsetDateTime; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.Function; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; + +class StructInternalRow extends InternalRow { + private final Types.StructType type; + private StructLike struct; + + StructInternalRow(Types.StructType type) { + this.type = type; + } + + private StructInternalRow(Types.StructType type, StructLike struct) { + this.type = type; + this.struct = struct; + } + + public StructInternalRow setStruct(StructLike newStruct) { + this.struct = newStruct; + return this; + } + + @Override + public int numFields() { + return struct.size(); + } + + @Override + public void setNullAt(int i) { + throw new UnsupportedOperationException("StructInternalRow is read-only"); + } + + @Override + public void update(int i, Object value) { + throw new UnsupportedOperationException("StructInternalRow is read-only"); + } + + @Override + public InternalRow copy() { + return this; + } + + @Override + public boolean isNullAt(int ordinal) { + return struct.get(ordinal, 
Object.class) == null; + } + + @Override + public boolean getBoolean(int ordinal) { + return struct.get(ordinal, Boolean.class); + } + + @Override + public byte getByte(int ordinal) { + return (byte) (int) struct.get(ordinal, Integer.class); + } + + @Override + public short getShort(int ordinal) { + return (short) (int) struct.get(ordinal, Integer.class); + } + + @Override + public int getInt(int ordinal) { + Object integer = struct.get(ordinal, Object.class); + + if (integer instanceof Integer) { + return (int) integer; + } else if (integer instanceof LocalDate) { + return (int) ((LocalDate) integer).toEpochDay(); + } else { + throw new IllegalStateException( + "Unknown type for int field. Type name: " + integer.getClass().getName()); + } + } + + @Override + public long getLong(int ordinal) { + Object longVal = struct.get(ordinal, Object.class); + + if (longVal instanceof Long) { + return (long) longVal; + } else if (longVal instanceof OffsetDateTime) { + return Duration.between(Instant.EPOCH, (OffsetDateTime) longVal).toNanos() / 1000; + } else if (longVal instanceof LocalDate) { + return ((LocalDate) longVal).toEpochDay(); + } else { + throw new IllegalStateException( + "Unknown type for long field. Type name: " + longVal.getClass().getName()); + } + } + + @Override + public float getFloat(int ordinal) { + return struct.get(ordinal, Float.class); + } + + @Override + public double getDouble(int ordinal) { + return struct.get(ordinal, Double.class); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return isNullAt(ordinal) ? null : getDecimalInternal(ordinal, precision, scale); + } + + private Decimal getDecimalInternal(int ordinal, int precision, int scale) { + return Decimal.apply(struct.get(ordinal, BigDecimal.class)); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + return isNullAt(ordinal) ? null : getUTF8StringInternal(ordinal); + } + + private UTF8String getUTF8StringInternal(int ordinal) { + CharSequence seq = struct.get(ordinal, CharSequence.class); + return UTF8String.fromString(seq.toString()); + } + + @Override + public byte[] getBinary(int ordinal) { + return isNullAt(ordinal) ? null : getBinaryInternal(ordinal); + } + + private byte[] getBinaryInternal(int ordinal) { + Object bytes = struct.get(ordinal, Object.class); + + // should only be either ByteBuffer or byte[] + if (bytes instanceof ByteBuffer) { + return ByteBuffers.toByteArray((ByteBuffer) bytes); + } else if (bytes instanceof byte[]) { + return (byte[]) bytes; + } else { + throw new IllegalStateException( + "Unknown type for binary field. Type name: " + bytes.getClass().getName()); + } + } + + @Override + public CalendarInterval getInterval(int ordinal) { + throw new UnsupportedOperationException("Unsupported type: interval"); + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + return isNullAt(ordinal) ? null : getStructInternal(ordinal, numFields); + } + + private InternalRow getStructInternal(int ordinal, int numFields) { + return new StructInternalRow( + type.fields().get(ordinal).type().asStructType(), struct.get(ordinal, StructLike.class)); + } + + @Override + public ArrayData getArray(int ordinal) { + return isNullAt(ordinal) ? 
null : getArrayInternal(ordinal); } + + private ArrayData getArrayInternal(int ordinal) { + return collectionToArrayData( + type.fields().get(ordinal).type().asListType().elementType(), + struct.get(ordinal, Collection.class)); + } + + @Override + public MapData getMap(int ordinal) { + return isNullAt(ordinal) ? null : getMapInternal(ordinal); + } + + @Override + public VariantVal getVariant(int ordinal) { + throw new UnsupportedOperationException("Unsupported method: getVariant"); + } + + private MapData getMapInternal(int ordinal) { + return mapToMapData( + type.fields().get(ordinal).type().asMapType(), struct.get(ordinal, Map.class)); + } + + @Override + @SuppressWarnings("checkstyle:CyclomaticComplexity") + public Object get(int ordinal, DataType dataType) { + if (isNullAt(ordinal)) { + return null; + } + + if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8StringInternal(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType decimalType = (DecimalType) dataType; + return getDecimalInternal(ordinal, decimalType.precision(), decimalType.scale()); + } else if (dataType instanceof BinaryType) { + return getBinaryInternal(ordinal); + } else if (dataType instanceof StructType) { + return getStructInternal(ordinal, ((StructType) dataType).size()); + } else if (dataType instanceof ArrayType) { + return getArrayInternal(ordinal); + } else if (dataType instanceof MapType) { + return getMapInternal(ordinal); + } else if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } + return null; + } + + private MapData mapToMapData(Types.MapType mapType, Map<?, ?> map) { + // make a defensive copy to ensure entries do not change + List<Map.Entry<?, ?>> entries = ImmutableList.copyOf(map.entrySet()); + return new ArrayBasedMapData( + collectionToArrayData(mapType.keyType(), Lists.transform(entries, Map.Entry::getKey)), + collectionToArrayData(mapType.valueType(), Lists.transform(entries, Map.Entry::getValue))); + } + + private ArrayData collectionToArrayData(Type elementType, Collection<?> values) { + switch (elementType.typeId()) { + case BOOLEAN: + case INTEGER: + case DATE: + case TIME: + case LONG: + case TIMESTAMP: + case FLOAT: + case DOUBLE: + return fillArray(values, array -> (pos, value) -> array[pos] = value); + case STRING: + return fillArray( + values, + array -> + (BiConsumer<Integer, CharSequence>) + (pos, seq) -> array[pos] = UTF8String.fromString(seq.toString())); + case FIXED: + case BINARY: + return fillArray( + values, + array -> + (BiConsumer<Integer, ByteBuffer>) + (pos, buf) -> array[pos] = ByteBuffers.toByteArray(buf)); + case DECIMAL: + return fillArray( + values, + array -> + (BiConsumer<Integer, BigDecimal>) (pos, dec) -> array[pos] = Decimal.apply(dec)); + case STRUCT: + return fillArray( + values, + array -> + (BiConsumer<Integer, StructLike>) + (pos, tuple) -> + array[pos] = new StructInternalRow(elementType.asStructType(), tuple)); + case LIST: + return fillArray( + values, + array -> + (BiConsumer<Integer, Collection<?>>) + (pos, list) -> + array[pos] = + collectionToArrayData(elementType.asListType().elementType(), list)); + case MAP: + return fillArray( + values, + array -> + (BiConsumer<Integer, Map<?, ?>>) + (pos, map) -> array[pos] = mapToMapData(elementType.asMapType(), map)); + default: + throw new UnsupportedOperationException("Unsupported array element type: " + elementType); + } + } + + @SuppressWarnings("unchecked") + private <T> GenericArrayData fillArray( + Collection<?> values, Function<Object[], BiConsumer<Integer, T>> makeSetter) { + Object[] array = new Object[values.size()]; + BiConsumer<Integer, T> setter = makeSetter.apply(array); + + int index = 0; + for (Object value : values) { + if (value == null) { + array[index] = null; + } else { + setter.accept(index, (T) value); + } + + index += 1; + } + + return new GenericArrayData(array); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + + StructInternalRow that = (StructInternalRow) other; + return type.equals(that.type) && struct.equals(that.struct); + } + + @Override + public int hashCode() { + return Objects.hash(type, struct); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/EqualityDeleteFiles.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/EqualityDeleteFiles.java new file mode 100644 index 000000000000..754145f7d252 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/EqualityDeleteFiles.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class EqualityDeleteFiles extends CustomSumMetric { + + static final String NAME = "equalityDeleteFiles"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "number of equality delete files"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/IndexedDeleteFiles.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/IndexedDeleteFiles.java new file mode 100644 index 000000000000..7fc5b9066cdc --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/IndexedDeleteFiles.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
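StructInternalRow is a read-only adapter that exposes an Iceberg StructLike through Spark's InternalRow API, converting values such as CharSequence, LocalDate and ByteBuffer on access; setNullAt and update throw, and copy() returns this. A minimal same-package usage sketch, where schema and record are assumed to already exist (an org.apache.iceberg.Schema and a matching StructLike; neither is defined in this diff):

// Hypothetical usage: wrap an Iceberg row so Spark code can read it as an InternalRow.
InternalRow row = new StructInternalRow(schema.asStruct()).setStruct(record);
long id = row.getLong(0); // reads through StructLike.get(0, ...)
UTF8String name = row.getUTF8String(1); // CharSequence converted to UTF8String on read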
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class IndexedDeleteFiles extends CustomSumMetric { + + static final String NAME = "indexedDeleteFiles"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "number of indexed delete files"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/NumDeletes.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/NumDeletes.java new file mode 100644 index 000000000000..000499874ba5 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/NumDeletes.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import java.text.NumberFormat; +import org.apache.spark.sql.connector.metric.CustomMetric; + +public class NumDeletes implements CustomMetric { + + public static final String DISPLAY_STRING = "number of row deletes applied"; + + @Override + public String name() { + return "numDeletes"; + } + + @Override + public String description() { + return DISPLAY_STRING; + } + + @Override + public String aggregateTaskMetrics(long[] taskMetrics) { + long sum = 0L; + for (long taskMetric : taskMetrics) { + sum += taskMetric; + } + + return NumberFormat.getIntegerInstance().format(sum); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/NumSplits.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/NumSplits.java new file mode 100644 index 000000000000..41d7c1e8db71 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/NumSplits.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import java.text.NumberFormat; +import org.apache.spark.sql.connector.metric.CustomMetric; + +public class NumSplits implements CustomMetric { + + @Override + public String name() { + return "numSplits"; + } + + @Override + public String description() { + return "number of file splits read"; + } + + @Override + public String aggregateTaskMetrics(long[] taskMetrics) { + long sum = 0L; + for (long taskMetric : taskMetrics) { + sum += taskMetric; + } + + return NumberFormat.getIntegerInstance().format(sum); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/PositionalDeleteFiles.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/PositionalDeleteFiles.java new file mode 100644 index 000000000000..5de75776ea4f --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/PositionalDeleteFiles.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
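NumDeletes and NumSplits implement CustomMetric directly and format the per-task total themselves with NumberFormat.getIntegerInstance(), so the rendered value is locale-dependent. A small sketch of the driver-side aggregation, assuming an en-US default locale (the locale is an assumption, not something the diff controls):

// Spark calls aggregateTaskMetrics on the driver with one value per task.
String display = new NumDeletes().aggregateTaskMetrics(new long[] {1_000_000L, 234_567L});
// display == "1,234,567" under en-US; other locales use different grouping separators.

The other metrics in this package extend CustomSumMetric, which already supplies a summing aggregateTaskMetrics, so they only declare a name and a description.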
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class ResultDataFiles extends CustomSumMetric { + + static final String NAME = "resultDataFiles"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "number of result data files"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/ResultDeleteFiles.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/ResultDeleteFiles.java new file mode 100644 index 000000000000..9c6ad2ca328a --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/ResultDeleteFiles.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class ResultDeleteFiles extends CustomSumMetric { + + static final String NAME = "resultDeleteFiles"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "number of result delete files"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/ScannedDataManifests.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/ScannedDataManifests.java new file mode 100644 index 000000000000..a167904280e6 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/ScannedDataManifests.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class ScannedDataManifests extends CustomSumMetric { + + static final String NAME = "scannedDataManifests"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "number of scanned data manifests"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/ScannedDeleteManifests.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/ScannedDeleteManifests.java new file mode 100644 index 000000000000..1fa006b7b193 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/ScannedDeleteManifests.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class ScannedDeleteManifests extends CustomSumMetric { + + static final String NAME = "scannedDeleteManifests"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "number of scanned delete manifests"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/SkippedDataFiles.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/SkippedDataFiles.java new file mode 100644 index 000000000000..7fd17425313d --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/SkippedDataFiles.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class SkippedDataFiles extends CustomSumMetric { + + static final String NAME = "skippedDataFiles"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "number of skipped data files"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/SkippedDataManifests.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/SkippedDataManifests.java new file mode 100644 index 000000000000..b0eaeb5d87f2 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/SkippedDataManifests.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class SkippedDataManifests extends CustomSumMetric { + + static final String NAME = "skippedDataManifests"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "number of skipped data manifests"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/SkippedDeleteFiles.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/SkippedDeleteFiles.java new file mode 100644 index 000000000000..70597be67113 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/SkippedDeleteFiles.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class SkippedDeleteFiles extends CustomSumMetric { + + static final String NAME = "skippedDeleteFiles"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "number of skipped delete files"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/SkippedDeleteManifests.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/SkippedDeleteManifests.java new file mode 100644 index 000000000000..0336170b45a1 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/SkippedDeleteManifests.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class SkippedDeleteManifests extends CustomSumMetric { + + static final String NAME = "skippedDeleteManifests"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "number of skipped delete manifest"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskEqualityDeleteFiles.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskEqualityDeleteFiles.java new file mode 100644 index 000000000000..ecd14bcca31d --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskEqualityDeleteFiles.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskEqualityDeleteFiles implements CustomTaskMetric { + private final long value; + + private TaskEqualityDeleteFiles(long value) { + this.value = value; + } + + @Override + public String name() { + return EqualityDeleteFiles.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskEqualityDeleteFiles from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().equalityDeleteFiles(); + long value = counter != null ? counter.value() : 0L; + return new TaskEqualityDeleteFiles(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskIndexedDeleteFiles.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskIndexedDeleteFiles.java new file mode 100644 index 000000000000..63b6767e955d --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskIndexedDeleteFiles.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskIndexedDeleteFiles implements CustomTaskMetric { + private final long value; + + private TaskIndexedDeleteFiles(long value) { + this.value = value; + } + + @Override + public String name() { + return IndexedDeleteFiles.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskIndexedDeleteFiles from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().indexedDeleteFiles(); + long value = counter != null ? counter.value() : 0L; + return new TaskIndexedDeleteFiles(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskNumDeletes.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskNumDeletes.java new file mode 100644 index 000000000000..8c734ba9f022 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskNumDeletes.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
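TaskEqualityDeleteFiles shows the pattern repeated by the Task* classes that follow: a CustomTaskMetric whose name() matches a driver-side metric, populated from the Iceberg ScanReport counters with a null-safe default of 0. A hedged sketch of how the two halves meet, assuming a ScanReport named scanReport is already available (it is not constructed anywhere in this diff):

// Executor side: extract the counter value reported by the scan.
CustomTaskMetric taskMetric = TaskEqualityDeleteFiles.from(scanReport);
// Driver side: Spark matches task values to driver metrics by name() and aggregates them;
// EqualityDeleteFiles extends CustomSumMetric, so the values are summed for the SQL UI.
CustomMetric driverMetric = new EqualityDeleteFiles();
String display = driverMetric.aggregateTaskMetrics(new long[] {taskMetric.value()});

The matching is purely by name(), which is why each driver metric and its task counterpart share a NAME constant.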
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskNumDeletes implements CustomTaskMetric { + private final long value; + + public TaskNumDeletes(long value) { + this.value = value; + } + + @Override + public String name() { + return "numDeletes"; + } + + @Override + public long value() { + return value; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskNumSplits.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskNumSplits.java new file mode 100644 index 000000000000..d8cbc4db05bb --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskNumSplits.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskNumSplits implements CustomTaskMetric { + private final long value; + + public TaskNumSplits(long value) { + this.value = value; + } + + @Override + public String name() { + return "numSplits"; + } + + @Override + public long value() { + return value; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskPositionalDeleteFiles.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskPositionalDeleteFiles.java new file mode 100644 index 000000000000..805f22bf0d7c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskPositionalDeleteFiles.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskPositionalDeleteFiles implements CustomTaskMetric { + private final long value; + + private TaskPositionalDeleteFiles(long value) { + this.value = value; + } + + @Override + public String name() { + return PositionalDeleteFiles.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskPositionalDeleteFiles from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().positionalDeleteFiles(); + long value = counter != null ? counter.value() : 0L; + return new TaskPositionalDeleteFiles(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskResultDataFiles.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskResultDataFiles.java new file mode 100644 index 000000000000..a27142131403 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskResultDataFiles.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskResultDataFiles implements CustomTaskMetric { + private final long value; + + private TaskResultDataFiles(long value) { + this.value = value; + } + + @Override + public String name() { + return ResultDataFiles.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskResultDataFiles from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().resultDataFiles(); + long value = counter != null ? counter.value() : 0L; + return new TaskResultDataFiles(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskResultDeleteFiles.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskResultDeleteFiles.java new file mode 100644 index 000000000000..aea8ca07dd05 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskResultDeleteFiles.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskResultDeleteFiles implements CustomTaskMetric { + private final long value; + + private TaskResultDeleteFiles(long value) { + this.value = value; + } + + @Override + public String name() { + return ResultDeleteFiles.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskResultDeleteFiles from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().resultDeleteFiles(); + long value = counter != null ? counter.value() : 0L; + return new TaskResultDeleteFiles(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskScannedDataManifests.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskScannedDataManifests.java new file mode 100644 index 000000000000..09dd0339910c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskScannedDataManifests.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskScannedDataManifests implements CustomTaskMetric { + private final long value; + + private TaskScannedDataManifests(long value) { + this.value = value; + } + + @Override + public String name() { + return ScannedDataManifests.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskScannedDataManifests from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().scannedDataManifests(); + long value = counter != null ? 
counter.value() : 0L; + return new TaskScannedDataManifests(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskScannedDeleteManifests.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskScannedDeleteManifests.java new file mode 100644 index 000000000000..1766cf2f6835 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskScannedDeleteManifests.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskScannedDeleteManifests implements CustomTaskMetric { + private final long value; + + private TaskScannedDeleteManifests(long value) { + this.value = value; + } + + @Override + public String name() { + return ScannedDeleteManifests.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskScannedDeleteManifests from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().scannedDeleteManifests(); + long value = counter != null ? counter.value() : 0L; + return new TaskScannedDeleteManifests(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskSkippedDataFiles.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskSkippedDataFiles.java new file mode 100644 index 000000000000..5165f9a3116c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskSkippedDataFiles.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskSkippedDataFiles implements CustomTaskMetric { + private final long value; + + private TaskSkippedDataFiles(long value) { + this.value = value; + } + + @Override + public String name() { + return SkippedDataFiles.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskSkippedDataFiles from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().skippedDataFiles(); + long value = counter != null ? counter.value() : 0L; + return new TaskSkippedDataFiles(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskSkippedDataManifests.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskSkippedDataManifests.java new file mode 100644 index 000000000000..86fef8c4118b --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskSkippedDataManifests.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskSkippedDataManifests implements CustomTaskMetric { + private final long value; + + private TaskSkippedDataManifests(long value) { + this.value = value; + } + + @Override + public String name() { + return SkippedDataManifests.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskSkippedDataManifests from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().skippedDataManifests(); + long value = counter != null ? counter.value() : 0L; + return new TaskSkippedDataManifests(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskSkippedDeleteFiles.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskSkippedDeleteFiles.java new file mode 100644 index 000000000000..87579751742c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskSkippedDeleteFiles.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskSkippedDeleteFiles implements CustomTaskMetric { + private final long value; + + private TaskSkippedDeleteFiles(long value) { + this.value = value; + } + + @Override + public String name() { + return SkippedDeleteFiles.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskSkippedDeleteFiles from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().skippedDeleteFiles(); + long value = counter != null ? counter.value() : 0L; + return new TaskSkippedDeleteFiles(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskSkippedDeleteManifests.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskSkippedDeleteManifests.java new file mode 100644 index 000000000000..4a9c71e0c1e4 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskSkippedDeleteManifests.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskSkippedDeleteManifests implements CustomTaskMetric { + private final long value; + + private TaskSkippedDeleteManifests(long value) { + this.value = value; + } + + @Override + public String name() { + return SkippedDeleteManifests.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskSkippedDeleteManifests from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().skippedDeleteManifests(); + long value = counter != null ? 
counter.value() : 0L; + return new TaskSkippedDeleteManifests(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalDataFileSize.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalDataFileSize.java new file mode 100644 index 000000000000..3f5a224425d8 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalDataFileSize.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskTotalDataFileSize implements CustomTaskMetric { + + private final long value; + + private TaskTotalDataFileSize(long value) { + this.value = value; + } + + @Override + public String name() { + return TotalDataFileSize.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskTotalDataFileSize from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().totalFileSizeInBytes(); + long value = counter != null ? counter.value() : 0L; + return new TaskTotalDataFileSize(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalDataManifests.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalDataManifests.java new file mode 100644 index 000000000000..6d8c3c24e460 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalDataManifests.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
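Each of the Task* classes above reports a single long value under the same name as its driver-side metric counterpart, which is how Spark matches per-task values to the metric shown in the SQL UI. A minimal sketch of bundling the "skipped" counters from a ScanReport follows; the helper class is illustrative only and not part of this patch, and the mention of Scan#reportDriverMetrics() is just one place such an array could be returned.

import org.apache.iceberg.metrics.ScanReport;
import org.apache.iceberg.spark.source.metrics.TaskSkippedDataFiles;
import org.apache.iceberg.spark.source.metrics.TaskSkippedDataManifests;
import org.apache.iceberg.spark.source.metrics.TaskSkippedDeleteFiles;
import org.apache.iceberg.spark.source.metrics.TaskSkippedDeleteManifests;
import org.apache.spark.sql.connector.metric.CustomTaskMetric;

// Illustrative helper: collects the "skipped" counters from a scan report in the array
// form Spark's metric API consumes (for example via Scan#reportDriverMetrics()).
final class SkippedMetricsExample {
  private SkippedMetricsExample() {}

  static CustomTaskMetric[] fromReport(ScanReport scanReport) {
    return new CustomTaskMetric[] {
      TaskSkippedDataFiles.from(scanReport),
      TaskSkippedDataManifests.from(scanReport),
      TaskSkippedDeleteFiles.from(scanReport),
      TaskSkippedDeleteManifests.from(scanReport)
    };
  }
}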
+ */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskTotalDataManifests implements CustomTaskMetric { + private final long value; + + private TaskTotalDataManifests(long value) { + this.value = value; + } + + @Override + public String name() { + return TotalDataManifests.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskTotalDataManifests from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().totalDataManifests(); + long value = counter != null ? counter.value() : 0L; + return new TaskTotalDataManifests(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalDeleteFileSize.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalDeleteFileSize.java new file mode 100644 index 000000000000..17ecec78da3f --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalDeleteFileSize.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskTotalDeleteFileSize implements CustomTaskMetric { + + private final long value; + + private TaskTotalDeleteFileSize(long value) { + this.value = value; + } + + @Override + public String name() { + return TotalDeleteFileSize.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskTotalDeleteFileSize from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().totalDeleteFileSizeInBytes(); + long value = counter != null ? counter.value() : 0L; + return new TaskTotalDeleteFileSize(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalDeleteManifests.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalDeleteManifests.java new file mode 100644 index 000000000000..ff55c1be89e3 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalDeleteManifests.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.CounterResult; +import org.apache.iceberg.metrics.ScanReport; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskTotalDeleteManifests implements CustomTaskMetric { + private final long value; + + private TaskTotalDeleteManifests(long value) { + this.value = value; + } + + @Override + public String name() { + return TotalDeleteManifests.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskTotalDeleteManifests from(ScanReport scanReport) { + CounterResult counter = scanReport.scanMetrics().totalDeleteManifests(); + long value = counter != null ? counter.value() : 0L; + return new TaskTotalDeleteManifests(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalPlanningDuration.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalPlanningDuration.java new file mode 100644 index 000000000000..32ac6fde8bf3 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskTotalPlanningDuration.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.iceberg.metrics.ScanReport; +import org.apache.iceberg.metrics.TimerResult; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskTotalPlanningDuration implements CustomTaskMetric { + + private final long value; + + private TaskTotalPlanningDuration(long value) { + this.value = value; + } + + @Override + public String name() { + return TotalPlanningDuration.NAME; + } + + @Override + public long value() { + return value; + } + + public static TaskTotalPlanningDuration from(ScanReport scanReport) { + TimerResult timerResult = scanReport.scanMetrics().totalPlanningDuration(); + long value = timerResult != null ? 
timerResult.totalDuration().toMillis() : -1; + return new TaskTotalPlanningDuration(value); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalDataFileSize.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalDataFileSize.java new file mode 100644 index 000000000000..b1ff8a46368c --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalDataFileSize.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class TotalDataFileSize extends CustomSumMetric { + + static final String NAME = "totalDataFileSize"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "total data file size (bytes)"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalDataManifests.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalDataManifests.java new file mode 100644 index 000000000000..de8f04be7767 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalDataManifests.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class TotalDataManifests extends CustomSumMetric { + + static final String NAME = "totalDataManifest"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "total data manifests"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalDeleteFileSize.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalDeleteFileSize.java new file mode 100644 index 000000000000..da4303325273 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalDeleteFileSize.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class TotalDeleteFileSize extends CustomSumMetric { + + static final String NAME = "totalDeleteFileSize"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "total delete file size (bytes)"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalDeleteManifests.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalDeleteManifests.java new file mode 100644 index 000000000000..7442dfdb6ffb --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalDeleteManifests.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class TotalDeleteManifests extends CustomSumMetric { + + static final String NAME = "totalDeleteManifests"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "total delete manifests"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalPlanningDuration.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalPlanningDuration.java new file mode 100644 index 000000000000..8b66eeac4046 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TotalPlanningDuration.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomSumMetric; + +public class TotalPlanningDuration extends CustomSumMetric { + + static final String NAME = "totalPlanningDuration"; + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "total planning duration (ms)"; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/spark/sql/catalyst/analysis/IcebergAnalysisException.java b/spark/v4.0/spark/src/main/java/org/apache/spark/sql/catalyst/analysis/IcebergAnalysisException.java new file mode 100644 index 000000000000..1953d7986632 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/spark/sql/catalyst/analysis/IcebergAnalysisException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
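The driver-side classes above all extend Spark's CustomSumMetric, so aggregation is a plain sum of the per-task values. A small illustration of that behavior; the class name is hypothetical and the rendered string assumes CustomSumMetric's default summing implementation.

import org.apache.iceberg.spark.source.metrics.TotalDataFileSize;
import org.apache.spark.sql.connector.metric.CustomMetric;

// Illustrative only: Spark calls aggregateTaskMetrics on the driver-side metric to produce
// the value shown in the SQL UI; CustomSumMetric sums the per-task longs.
final class MetricAggregationExample {
  private MetricAggregationExample() {}

  static String render(long[] perTaskValues) {
    CustomMetric metric = new TotalDataFileSize();
    return metric.aggregateTaskMetrics(perTaskValues); // {1024L, 2048L} should yield "3072"
  }
}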
+ */ +package org.apache.spark.sql.catalyst.analysis; + +import org.apache.spark.QueryContext; +import org.apache.spark.sql.AnalysisException; +import scala.Option; +import scala.collection.immutable.Map$; + +public class IcebergAnalysisException extends AnalysisException { + public IcebergAnalysisException(String message) { + super( + message, + Option.empty(), + Option.empty(), + Option.empty(), + Option.empty(), + Map$.MODULE$.empty(), + new QueryContext[0]); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/spark/sql/catalyst/analysis/NoSuchProcedureException.java b/spark/v4.0/spark/src/main/java/org/apache/spark/sql/catalyst/analysis/NoSuchProcedureException.java new file mode 100644 index 000000000000..9ed7167c94c6 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/spark/sql/catalyst/analysis/NoSuchProcedureException.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.catalyst.analysis; + +import org.apache.spark.QueryContext; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.connector.catalog.Identifier; +import scala.Option; +import scala.collection.immutable.Map$; + +public class NoSuchProcedureException extends AnalysisException { + public NoSuchProcedureException(Identifier ident) { + super( + "Procedure " + ident + " not found", + Option.empty(), + Option.empty(), + Option.empty(), + Option.empty(), + Map$.MODULE$.empty(), + new QueryContext[0]); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/Procedure.java b/spark/v4.0/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/Procedure.java new file mode 100644 index 000000000000..11f215ba040a --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/Procedure.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
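IcebergAnalysisException and NoSuchProcedureException fill in the Option-heavy AnalysisException constructor so that Java code can raise analysis errors without dealing with Scala default arguments. A hypothetical call site; the helper class, column, and message are made up for illustration.

import org.apache.spark.sql.catalyst.analysis.IcebergAnalysisException;

// Hypothetical helper: builds a user-facing analysis error for a missing column.
final class AnalysisErrorExample {
  private AnalysisErrorExample() {}

  static IcebergAnalysisException missingColumn(String column, String table) {
    return new IcebergAnalysisException(
        String.format("Cannot find column '%s' in table %s", column, table));
  }
}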
+ */ +package org.apache.spark.sql.connector.iceberg.catalog; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; + +/** An interface representing a stored procedure available for execution. */ +public interface Procedure { + /** Returns the input parameters of this procedure. */ + ProcedureParameter[] parameters(); + + /** Returns the type of rows produced by this procedure. */ + StructType outputType(); + + /** + * Executes this procedure. + * + *
<p>
Spark will align the provided arguments according to the input parameters defined in {@link + * #parameters()} either by position or by name before execution. + * + *
<p>
Implementations may provide a summary of execution by returning one or many rows as a + * result. The schema of output rows must match the defined output type in {@link #outputType()}. + * + * @param args input arguments + * @return the result of executing this procedure with the given arguments + */ + InternalRow[] call(InternalRow args); + + /** Returns the description of this procedure. */ + default String description() { + return this.getClass().toString(); + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureCatalog.java b/spark/v4.0/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureCatalog.java new file mode 100644 index 000000000000..2cee97ee5938 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureCatalog.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.connector.iceberg.catalog; + +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; + +/** + * A catalog API for working with stored procedures. + * + *
<p>
Implementations should implement this interface if they expose stored procedures that can be + * called via CALL statements. + */ +public interface ProcedureCatalog extends CatalogPlugin { + /** + * Load a {@link Procedure stored procedure} by {@link Identifier identifier}. + * + * @param ident a stored procedure identifier + * @return the stored procedure's metadata + * @throws NoSuchProcedureException if there is no matching stored procedure + */ + Procedure loadProcedure(Identifier ident) throws NoSuchProcedureException; +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameter.java b/spark/v4.0/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameter.java new file mode 100644 index 000000000000..e1e84b2597f3 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameter.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.connector.iceberg.catalog; + +import org.apache.spark.sql.types.DataType; + +/** An input parameter of a {@link Procedure stored procedure}. */ +public interface ProcedureParameter { + + /** + * Creates a required input parameter. + * + * @param name the name of the parameter + * @param dataType the type of the parameter + * @return the constructed stored procedure parameter + */ + static ProcedureParameter required(String name, DataType dataType) { + return new ProcedureParameterImpl(name, dataType, true); + } + + /** + * Creates an optional input parameter. + * + * @param name the name of the parameter. + * @param dataType the type of the parameter. + * @return the constructed optional stored procedure parameter + */ + static ProcedureParameter optional(String name, DataType dataType) { + return new ProcedureParameterImpl(name, dataType, false); + } + + /** Returns the name of this parameter. */ + String name(); + + /** Returns the type of this parameter. */ + DataType dataType(); + + /** Returns true if this parameter is required. */ + boolean required(); +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameterImpl.java b/spark/v4.0/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameterImpl.java new file mode 100644 index 000000000000..c59951e24330 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameterImpl.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
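Putting Procedure, ProcedureParameter, and ProcedureCatalog together, a connector declares parameters, an output schema, and a call() body. The EchoProcedure below is purely illustrative (not part of this patch); it assumes some ProcedureCatalog implementation would return it from loadProcedure() and simply echoes its input string as a one-row result.

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.connector.iceberg.catalog.Procedure;
import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// Hypothetical procedure for a CALL catalog.system.echo('hello') style invocation.
class EchoProcedure implements Procedure {
  private static final ProcedureParameter[] PARAMETERS =
      new ProcedureParameter[] {ProcedureParameter.required("message", DataTypes.StringType)};

  private static final StructType OUTPUT_TYPE =
      new StructType(
          new StructField[] {
            new StructField("echoed", DataTypes.StringType, false, Metadata.empty())
          });

  @Override
  public ProcedureParameter[] parameters() {
    return PARAMETERS;
  }

  @Override
  public StructType outputType() {
    return OUTPUT_TYPE;
  }

  @Override
  public InternalRow[] call(InternalRow args) {
    // Spark aligns the single required argument to position 0 before invoking call()
    Object message = args.getUTF8String(0);
    return new InternalRow[] {new GenericInternalRow(new Object[] {message})};
  }
}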
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.connector.iceberg.catalog; + +import java.util.Objects; +import org.apache.spark.sql.types.DataType; + +/** A {@link ProcedureParameter} implementation. */ +class ProcedureParameterImpl implements ProcedureParameter { + private final String name; + private final DataType dataType; + private final boolean required; + + ProcedureParameterImpl(String name, DataType dataType, boolean required) { + this.name = name; + this.dataType = dataType; + this.required = required; + } + + @Override + public String name() { + return name; + } + + @Override + public DataType dataType() { + return dataType; + } + + @Override + public boolean required() { + return required; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other == null || getClass() != other.getClass()) { + return false; + } + + ProcedureParameterImpl that = (ProcedureParameterImpl) other; + return required == that.required + && Objects.equals(name, that.name) + && Objects.equals(dataType, that.dataType); + } + + @Override + public int hashCode() { + return Objects.hash(name, dataType, required); + } + + @Override + public String toString() { + return String.format( + "ProcedureParameter(name='%s', type=%s, required=%b)", name, dataType, required); + } +} diff --git a/spark/v4.0/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/v4.0/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..01a6c4e0670d --- /dev/null +++ b/spark/v4.0/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +org.apache.iceberg.spark.source.IcebergSource diff --git a/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/OrderAwareCoalesce.scala b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/OrderAwareCoalesce.scala new file mode 100644 index 000000000000..5acaa6800e68 --- /dev/null +++ b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/OrderAwareCoalesce.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.rdd.PartitionCoalescer +import org.apache.spark.rdd.PartitionGroup +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Attribute + +// this node doesn't extend RepartitionOperation on purpose to keep this logic isolated +// and ignore it in optimizer rules such as CollapseRepartition +case class OrderAwareCoalesce( + numPartitions: Int, + coalescer: PartitionCoalescer, + child: LogicalPlan) extends OrderPreservingUnaryNode { + + override def output: Seq[Attribute] = child.output + + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = { + copy(child = newChild) + } +} + +class OrderAwareCoalescer(val groupSize: Int) extends PartitionCoalescer with Serializable { + + override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = { + val partitionBins = parent.partitions.grouped(groupSize) + partitionBins.map { partitions => + val group = new PartitionGroup() + group.partitions ++= partitions + group + }.toArray + } +} diff --git a/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala new file mode 100644 index 000000000000..7b599eb3da1d --- /dev/null +++ b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
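OrderAwareCoalescer bins consecutive parent partitions into fixed-size groups, which is what lets the coalesce keep each task's sort order intact. The resulting partition count is simple arithmetic; the small sketch below is illustrative and not part of the patch.

// Illustrative only: with 10 parent partitions and groupSize = 4 the coalescer produces
// 3 output partitions covering parent partitions [0-3], [4-7], and [8-9].
final class CoalesceBinningExample {
  private CoalesceBinningExample() {}

  static int outputPartitions(int parentPartitions, int groupSize) {
    return (parentPartitions + groupSize - 1) / groupSize; // ceil(parent / groupSize)
  }
}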
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.iceberg.DistributionMode +import org.apache.iceberg.NullOrder +import org.apache.iceberg.SortDirection +import org.apache.iceberg.expressions.Term +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits + +case class SetWriteDistributionAndOrdering( + table: Seq[String], + distributionMode: Option[DistributionMode], + sortOrder: Seq[(Term, SortDirection, NullOrder)]) extends LeafCommand { + + import CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + val order = sortOrder.map { + case (term, direction, nullOrder) => s"$term $direction $nullOrder" + }.mkString(", ") + s"SetWriteDistributionAndOrdering ${table.quoted} $distributionMode $order" + } +} diff --git a/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SortOrderParserUtil.scala b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SortOrderParserUtil.scala new file mode 100644 index 000000000000..bf19ef8a2167 --- /dev/null +++ b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SortOrderParserUtil.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.iceberg.NullOrder +import org.apache.iceberg.Schema +import org.apache.iceberg.SortDirection +import org.apache.iceberg.SortOrder +import org.apache.iceberg.expressions.Term + +class SortOrderParserUtil { + + def collectSortOrder(tableSchema:Schema, sortOrder: Seq[(Term, SortDirection, NullOrder)]): SortOrder = { + val orderBuilder = SortOrder.builderFor(tableSchema) + sortOrder.foreach { + case (term, SortDirection.ASC, nullOrder) => + orderBuilder.asc(term, nullOrder) + case (term, SortDirection.DESC, nullOrder) => + orderBuilder.desc(term, nullOrder) + } + orderBuilder.build(); + } +} diff --git a/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanUtils.scala b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanUtils.scala new file mode 100644 index 000000000000..aa9e9c553346 --- /dev/null +++ b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanUtils.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
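SortOrderParserUtil turns the parsed (term, direction, null order) triples carried by SetWriteDistributionAndOrdering into an Iceberg SortOrder. For orientation, here is a sketch of an equivalent order built directly with Iceberg's builder; the schema, column names, and quoted SQL are made up for illustration.

import org.apache.iceberg.NullOrder;
import org.apache.iceberg.Schema;
import org.apache.iceberg.SortOrder;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.types.Types;

// Illustrative only: roughly what "WRITE ORDERED BY category ASC NULLS FIRST, ts DESC"
// would produce for a two-column schema.
final class SortOrderExample {
  private SortOrderExample() {}

  static SortOrder example() {
    Schema schema =
        new Schema(
            Types.NestedField.required(1, "category", Types.StringType.get()),
            Types.NestedField.required(2, "ts", Types.TimestampType.withZone()));
    return SortOrder.builderFor(schema)
        .asc(Expressions.ref("category"), NullOrder.NULLS_FIRST)
        .desc("ts")
        .build();
  }
}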
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.utils + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import scala.annotation.tailrec + +object PlanUtils { + @tailrec + def isIcebergRelation(plan: LogicalPlan): Boolean = { + def isIcebergTable(relation: DataSourceV2Relation): Boolean = relation.table match { + case _: SparkTable => true + case _ => false + } + + plan match { + case s: SubqueryAlias => isIcebergRelation(s.child) + case r: DataSourceV2Relation => isIcebergTable(r) + case _ => false + } + } +} diff --git a/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/execution/OrderAwareCoalesceExec.scala b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/execution/OrderAwareCoalesceExec.scala new file mode 100644 index 000000000000..2ef99550524a --- /dev/null +++ b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/execution/OrderAwareCoalesceExec.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
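PlanUtils.isIcebergRelation walks through SubqueryAlias wrappers until it reaches a DataSourceV2Relation and checks whether that relation's table is an Iceberg SparkTable. Since top-level Scala objects expose static forwarders, it can be called from Java; the check below is a hypothetical usage, not part of the patch.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.utils.PlanUtils;

// Hypothetical usage: true only when the DataFrame's analyzed plan is itself an Iceberg relation.
final class RelationCheckExample {
  private RelationCheckExample() {}

  static boolean readsIcebergTable(Dataset<Row> df) {
    return PlanUtils.isIcebergRelation(df.queryExecution().analyzed());
  }
}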
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.PartitionCoalescer +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning + +case class OrderAwareCoalesceExec( + numPartitions: Int, + coalescer: PartitionCoalescer, + child: SparkPlan) extends UnaryExecNode { + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = { + if (numPartitions == 1) SinglePartition else UnknownPartitioning(numPartitions) + } + + protected override def doExecute(): RDD[InternalRow] = { + val result = child.execute() + if (numPartitions == 1 && result.getNumPartitions < 1) { + // make sure we don't output an RDD with 0 partitions, + // when claiming that we have a `SinglePartition` + // see CoalesceExec in Spark + new CoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions) + } else { + result.coalesce(numPartitions, shuffle = false, Some(coalescer)) + } + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = { + copy(child = newChild) + } +} diff --git a/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala new file mode 100644 index 000000000000..6bd49cfca5b7 --- /dev/null +++ b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.iceberg.spark.SparkV2Filters +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis.IcebergAnalysisException +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy + +object SparkExpressionConverter { + + def convertToIcebergExpression(sparkExpression: Expression): org.apache.iceberg.expressions.Expression = { + // Currently, it is a double conversion as we are converting Spark expression to Spark predicate + // and then converting Spark predicate to Iceberg expression. + // But these two conversions already exist and well tested. So, we are going with this approach. + DataSourceV2Strategy.translateFilterV2(sparkExpression) match { + case Some(filter) => + val converted = SparkV2Filters.convert(filter) + if (converted == null) { + throw new IllegalArgumentException(s"Cannot convert Spark filter: $filter to Iceberg expression") + } + + converted + case _ => + throw new IllegalArgumentException(s"Cannot translate Spark expression: $sparkExpression to data source filter") + } + } + + @throws[IcebergAnalysisException] + def collectResolvedSparkExpression(session: SparkSession, tableName: String, where: String): Expression = { + val tableAttrs = session.table(tableName).queryExecution.analyzed.output + val unresolvedExpression = session.sessionState.sqlParser.parseExpression(where) + val filter = Filter(unresolvedExpression, DummyRelation(tableAttrs)) + val optimizedLogicalPlan = session.sessionState.executePlan(filter).optimizedPlan + optimizedLogicalPlan.collectFirst { + case filter: Filter => filter.condition + case _: DummyRelation => Literal.TrueLiteral + case _: LocalRelation => Literal.FalseLiteral + }.getOrElse(throw new IcebergAnalysisException("Failed to find filter expression")) + } + + case class DummyRelation(output: Seq[Attribute]) extends LeafNode +} diff --git a/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/stats/ThetaSketchAgg.scala b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/stats/ThetaSketchAgg.scala new file mode 100644 index 000000000000..6995b5e2a49c --- /dev/null +++ b/spark/v4.0/spark/src/main/scala/org/apache/spark/sql/stats/ThetaSketchAgg.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
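SparkExpressionConverter resolves a WHERE string against a table and then converts the resolved Catalyst expression into an Iceberg expression by going through Spark's V2 predicate translation. A hypothetical Java caller is sketched below; the wrapper class, table name, and predicate are illustrative assumptions.

import org.apache.iceberg.expressions.Expression;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.analysis.IcebergAnalysisException;
import org.apache.spark.sql.execution.datasources.SparkExpressionConverter;

// Hypothetical usage: turn a WHERE string such as "ts < TIMESTAMP '2024-01-01'" into an
// Iceberg filter expression for the given table.
final class WhereClauseExample {
  private WhereClauseExample() {}

  static Expression toIcebergFilter(SparkSession spark, String tableName, String where)
      throws IcebergAnalysisException {
    org.apache.spark.sql.catalyst.expressions.Expression resolved =
        SparkExpressionConverter.collectResolvedSparkExpression(spark, tableName, where);
    return SparkExpressionConverter.convertToIcebergExpression(resolved);
  }
}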
+ */ + +package org.apache.spark.sql.stats + +import java.nio.ByteBuffer +import org.apache.datasketches.common.Family +import org.apache.datasketches.memory.Memory +import org.apache.datasketches.theta.CompactSketch +import org.apache.datasketches.theta.SetOperationBuilder +import org.apache.datasketches.theta.Sketch +import org.apache.datasketches.theta.UpdateSketch +import org.apache.iceberg.spark.SparkSchemaUtil +import org.apache.iceberg.types.Conversions +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate +import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.BinaryType +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.unsafe.types.UTF8String + +/** + * ThetaSketchAgg generates Alpha family sketch with default seed. + * The values fed to the sketch are converted to bytes using Iceberg's single value serialization. + * The result returned is an array of bytes of Compact Theta sketch of Datasketches library, + * which should be deserialized to Compact sketch before using. + * + * See [[https://iceberg.apache.org/puffin-spec/]] for more information. + * + */ +case class ThetaSketchAgg( + child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[Sketch] with UnaryLike[Expression] { + + private lazy val icebergType = SparkSchemaUtil.convert(child.dataType) + + def this(colName: String) = { + this(analysis.UnresolvedAttribute.quotedString(colName), 0, 0) + } + + override def dataType: DataType = BinaryType + + override def nullable: Boolean = false + + override def createAggregationBuffer(): Sketch = { + UpdateSketch.builder.setFamily(Family.ALPHA).build() + } + + override def update(buffer: Sketch, input: InternalRow): Sketch = { + val value = child.eval(input) + if (value != null) { + val icebergValue = toIcebergValue(value) + val byteBuffer = Conversions.toByteBuffer(icebergType, icebergValue) + buffer.asInstanceOf[UpdateSketch].update(byteBuffer) + } + buffer + } + + private def toIcebergValue(value: Any): Any = { + value match { + case s: UTF8String => s.toString + case d: Decimal => d.toJavaBigDecimal + case b: Array[Byte] => ByteBuffer.wrap(b) + case _ => value + } + } + + override def merge(buffer: Sketch, input: Sketch): Sketch = { + new SetOperationBuilder().buildUnion.union(buffer, input) + } + + override def eval(buffer: Sketch): Any = { + toBytes(buffer) + } + + override def serialize(buffer: Sketch): Array[Byte] = { + toBytes(buffer) + } + + override def deserialize(storageFormat: Array[Byte]): Sketch = { + CompactSketch.wrap(Memory.wrap(storageFormat)) + } + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = { + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + } + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = { + copy(inputAggBufferOffset = newInputAggBufferOffset) + } + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) + } + + private def toBytes(sketch: Sketch): Array[Byte] = { + val compactSketch = sketch.compact() + compactSketch.toByteArray + 
} +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/KryoHelpers.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/KryoHelpers.java new file mode 100644 index 000000000000..6d88aaa11813 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/KryoHelpers.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.KryoSerializer; + +public class KryoHelpers { + + private KryoHelpers() {} + + @SuppressWarnings("unchecked") + public static T roundTripSerialize(T obj) throws IOException { + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + + try (Output out = new Output(new ObjectOutputStream(bytes))) { + kryo.writeClassAndObject(out, obj); + } + + try (Input in = + new Input(new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray())))) { + return (T) kryo.readClassAndObject(in); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/SparkDistributedDataScanTestBase.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/SparkDistributedDataScanTestBase.java new file mode 100644 index 000000000000..404ba7284606 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/SparkDistributedDataScanTestBase.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
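The binary values produced by ThetaSketchAgg above are serialized compact theta sketches, so a consumer has to wrap them before reading the distinct-count estimate. A minimal reader sketch; the class name is illustrative and not part of the patch.

import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.theta.CompactSketch;

// Illustrative only: wrap the aggregated bytes and read the NDV estimate.
final class SketchReadExample {
  private SketchReadExample() {}

  static double estimateDistinct(byte[] serializedSketch) {
    CompactSketch sketch = CompactSketch.wrap(Memory.wrap(serializedSketch));
    return sketch.getEstimate();
  }
}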
+ */ +package org.apache.iceberg; + +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; + +import java.util.Arrays; +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public abstract class SparkDistributedDataScanTestBase + extends DataTableScanTestBase> { + + @Parameters(name = "formatVersion = {0}, dataMode = {1}, deleteMode = {2}") + public static List parameters() { + return Arrays.asList( + new Object[] {1, LOCAL, LOCAL}, + new Object[] {1, LOCAL, DISTRIBUTED}, + new Object[] {1, DISTRIBUTED, LOCAL}, + new Object[] {1, DISTRIBUTED, DISTRIBUTED}, + new Object[] {2, LOCAL, LOCAL}, + new Object[] {2, LOCAL, DISTRIBUTED}, + new Object[] {2, DISTRIBUTED, LOCAL}, + new Object[] {2, DISTRIBUTED, DISTRIBUTED}); + } + + protected static SparkSession spark = null; + + @Parameter(index = 1) + private PlanningMode dataMode; + + @Parameter(index = 2) + private PlanningMode deleteMode; + + @BeforeEach + public void configurePlanningModes() { + table + .updateProperties() + .set(TableProperties.DATA_PLANNING_MODE, dataMode.modeName()) + .set(TableProperties.DELETE_PLANNING_MODE, deleteMode.modeName()) + .commit(); + } + + @Override + protected BatchScan useRef(BatchScan scan, String ref) { + return scan.useRef(ref); + } + + @Override + protected BatchScan useSnapshot(BatchScan scan, long snapshotId) { + return scan.useSnapshot(snapshotId); + } + + @Override + protected BatchScan asOfTime(BatchScan scan, long timestampMillis) { + return scan.asOfTime(timestampMillis); + } + + @Override + protected BatchScan newScan() { + SparkReadConf readConf = new SparkReadConf(spark, table, ImmutableMap.of()); + return new SparkDistributedDataScan(spark, table, readConf); + } + + protected static SparkSession initSpark(String serializer) { + return SparkSession.builder() + .master("local[2]") + .config("spark.serializer", serializer) + .config(SQLConf.SHUFFLE_PARTITIONS().key(), "4") + .getOrCreate(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/TaskCheckHelper.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TaskCheckHelper.java new file mode 100644 index 000000000000..bcd00eb6f4e5 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TaskCheckHelper.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
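SparkDistributedDataScanTestBase leaves SparkSession creation to subclasses so the same distributed-planning scan tests can run with different serializers. The concrete subclass below is a hypothetical sketch (assumed to sit in org.apache.iceberg next to the base class so the inherited spark field and initSpark helper resolve).

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;

// Hypothetical subclass: runs the inherited scan tests with Kryo serialization.
class TestSparkDistributedDataScanKryoExample extends SparkDistributedDataScanTestBase {

  @BeforeAll
  public static void startSpark() {
    spark = initSpark("org.apache.spark.serializer.KryoSerializer");
  }

  @AfterAll
  public static void stopSpark() {
    if (spark != null) {
      spark.stop();
      spark = null;
    }
  }
}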
+ */ +package org.apache.iceberg; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +public final class TaskCheckHelper { + private TaskCheckHelper() {} + + public static void assertEquals( + ScanTaskGroup expected, ScanTaskGroup actual) { + List expectedTasks = getFileScanTasksInFilePathOrder(expected); + List actualTasks = getFileScanTasksInFilePathOrder(actual); + + assertThat(actualTasks) + .as("The number of file scan tasks should match") + .hasSameSizeAs(expectedTasks); + + for (int i = 0; i < expectedTasks.size(); i++) { + FileScanTask expectedTask = expectedTasks.get(i); + FileScanTask actualTask = actualTasks.get(i); + assertEquals(expectedTask, actualTask); + } + } + + public static void assertEquals(FileScanTask expected, FileScanTask actual) { + assertEquals(expected.file(), actual.file()); + + // PartitionSpec implements its own equals method + assertThat(actual.spec()).as("PartitionSpec doesn't match").isEqualTo(expected.spec()); + + assertThat(actual.start()).as("starting position doesn't match").isEqualTo(expected.start()); + + assertThat(actual.start()) + .as("the number of bytes to scan doesn't match") + .isEqualTo(expected.start()); + + // simplify comparison on residual expression via comparing toString + assertThat(actual.residual().toString()) + .as("Residual expression doesn't match") + .isEqualTo(expected.residual().toString()); + } + + public static void assertEquals(DataFile expected, DataFile actual) { + assertThat(actual.location()) + .as("Should match the serialized record path") + .isEqualTo(expected.location()); + assertThat(actual.format()) + .as("Should match the serialized record format") + .isEqualTo(expected.format()); + assertThat(actual.partition().get(0, Object.class)) + .as("Should match the serialized record partition") + .isEqualTo(expected.partition().get(0, Object.class)); + assertThat(actual.recordCount()) + .as("Should match the serialized record count") + .isEqualTo(expected.recordCount()); + assertThat(actual.fileSizeInBytes()) + .as("Should match the serialized record size") + .isEqualTo(expected.fileSizeInBytes()); + assertThat(actual.valueCounts()) + .as("Should match the serialized record value counts") + .isEqualTo(expected.valueCounts()); + assertThat(actual.nullValueCounts()) + .as("Should match the serialized record null value counts") + .isEqualTo(expected.nullValueCounts()); + assertThat(actual.lowerBounds()) + .as("Should match the serialized record lower bounds") + .isEqualTo(expected.lowerBounds()); + assertThat(actual.upperBounds()) + .as("Should match the serialized record upper bounds") + .isEqualTo(expected.upperBounds()); + assertThat(actual.keyMetadata()) + .as("Should match the serialized record key metadata") + .isEqualTo(expected.keyMetadata()); + assertThat(actual.splitOffsets()) + .as("Should match the serialized record offsets") + .isEqualTo(expected.splitOffsets()); + assertThat(actual.keyMetadata()) + .as("Should match the serialized record offsets") + .isEqualTo(expected.keyMetadata()); + } + + private static List getFileScanTasksInFilePathOrder( + ScanTaskGroup taskGroup) { + return taskGroup.tasks().stream() + // use file path + start position to differentiate the tasks + .sorted(Comparator.comparing(o -> o.file().location() + "##" + o.start())) + .collect(Collectors.toList()); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestDataFileSerialization.java 
b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestDataFileSerialization.java new file mode 100644 index 000000000000..57c4dc7cdf23 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestDataFileSerialization.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import static org.apache.iceberg.TaskCheckHelper.assertEquals; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Path; +import java.util.Map; +import java.util.UUID; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.KryoSerializer; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestDataFileSerialization { + + private static final Schema DATE_SCHEMA = + new Schema( + required(1, "id", Types.LongType.get()), + optional(2, "data", Types.StringType.get()), + required(3, "date", Types.StringType.get()), + optional(4, "double", Types.DoubleType.get())); + + private static final PartitionSpec PARTITION_SPEC = + PartitionSpec.builderFor(DATE_SCHEMA).identity("date").build(); + + private static final Map VALUE_COUNTS = Maps.newHashMap(); + private static final Map NULL_VALUE_COUNTS = Maps.newHashMap(); + private static final Map NAN_VALUE_COUNTS = Maps.newHashMap(); + private static final Map LOWER_BOUNDS = Maps.newHashMap(); + private static final Map UPPER_BOUNDS = Maps.newHashMap(); + + static { + VALUE_COUNTS.put(1, 5L); + VALUE_COUNTS.put(2, 3L); + VALUE_COUNTS.put(4, 2L); + NULL_VALUE_COUNTS.put(1, 0L); + NULL_VALUE_COUNTS.put(2, 2L); + NAN_VALUE_COUNTS.put(4, 1L); + LOWER_BOUNDS.put(1, longToBuffer(0L)); + UPPER_BOUNDS.put(1, longToBuffer(4L)); + } + + 
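+  // Shared fixture: a DataFile carrying metrics, split offsets, encryption key metadata, and an
+  // unsorted sort order; the Kryo and Java serialization tests below round-trip it and compare
+  // the result with TaskCheckHelper.assertEquals.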
private static final DataFile DATA_FILE = + DataFiles.builder(PARTITION_SPEC) + .withPath("/path/to/data-1.parquet") + .withFileSizeInBytes(1234) + .withPartitionPath("date=2018-06-08") + .withMetrics( + new Metrics( + 5L, + null, + VALUE_COUNTS, + NULL_VALUE_COUNTS, + NAN_VALUE_COUNTS, + LOWER_BOUNDS, + UPPER_BOUNDS)) + .withSplitOffsets(ImmutableList.of(4L)) + .withEncryptionKeyMetadata(ByteBuffer.allocate(4).putInt(34)) + .withSortOrder(SortOrder.unsorted()) + .build(); + + @TempDir private Path temp; + + @Test + public void testDataFileKryoSerialization() throws Exception { + File data = File.createTempFile("junit", null, temp.toFile()); + assertThat(data.delete()).isTrue(); + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + try (Output out = new Output(new FileOutputStream(data))) { + kryo.writeClassAndObject(out, DATA_FILE); + kryo.writeClassAndObject(out, DATA_FILE.copy()); + } + + try (Input in = new Input(new FileInputStream(data))) { + for (int i = 0; i < 2; i += 1) { + Object obj = kryo.readClassAndObject(in); + assertThat(obj).as("Should be a DataFile").isInstanceOf(DataFile.class); + assertEquals(DATA_FILE, (DataFile) obj); + } + } + } + + @Test + public void testDataFileJavaSerialization() throws Exception { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + try (ObjectOutputStream out = new ObjectOutputStream(bytes)) { + out.writeObject(DATA_FILE); + out.writeObject(DATA_FILE.copy()); + } + + try (ObjectInputStream in = + new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) { + for (int i = 0; i < 2; i += 1) { + Object obj = in.readObject(); + assertThat(obj).as("Should be a DataFile").isInstanceOf(DataFile.class); + assertEquals(DATA_FILE, (DataFile) obj); + } + } + } + + @Test + public void testParquetWriterSplitOffsets() throws IOException { + Iterable records = RandomData.generateSpark(DATE_SCHEMA, 1, 33L); + File parquetFile = + new File(temp.toFile(), FileFormat.PARQUET.addExtension(UUID.randomUUID().toString())); + FileAppender writer = + Parquet.write(Files.localOutput(parquetFile)) + .schema(DATE_SCHEMA) + .createWriterFunc( + msgType -> + SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(DATE_SCHEMA), msgType)) + .build(); + try { + writer.addAll(records); + } finally { + writer.close(); + } + + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + File dataFile = File.createTempFile("junit", null, temp.toFile()); + try (Output out = new Output(new FileOutputStream(dataFile))) { + kryo.writeClassAndObject(out, writer.splitOffsets()); + } + try (Input in = new Input(new FileInputStream(dataFile))) { + kryo.readClassAndObject(in); + } + } + + private static ByteBuffer longToBuffer(long value) { + return ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(0, value); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestFileIOSerialization.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestFileIOSerialization.java new file mode 100644 index 000000000000..bfdfa8deca06 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestFileIOSerialization.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.hadoop.HadoopFileIO; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.SerializableTableWithSize; +import org.apache.iceberg.types.Types; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestFileIOSerialization { + + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + private static final Schema SCHEMA = + new Schema( + required(1, "id", Types.LongType.get()), + optional(2, "data", Types.StringType.get()), + required(3, "date", Types.StringType.get()), + optional(4, "double", Types.DoubleType.get())); + + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA).identity("date").build(); + + private static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA).asc("id").build(); + + static { + CONF.set("k1", "v1"); + CONF.set("k2", "v2"); + } + + @TempDir private Path temp; + private Table table; + + @BeforeEach + public void initTable() throws IOException { + Map props = ImmutableMap.of("k1", "v1", "k2", "v2"); + + File tableLocation = Files.createTempDirectory(temp, "junit").toFile(); + assertThat(tableLocation.delete()).isTrue(); + + this.table = TABLES.create(SCHEMA, SPEC, SORT_ORDER, props, tableLocation.toString()); + } + + @Test + public void testHadoopFileIOKryoSerialization() throws IOException { + FileIO io = table.io(); + Configuration expectedConf = ((HadoopFileIO) io).conf(); + + Table serializableTable = SerializableTableWithSize.copyOf(table); + FileIO deserializedIO = KryoHelpers.roundTripSerialize(serializableTable.io()); + Configuration actualConf = ((HadoopFileIO) deserializedIO).conf(); + + assertThat(toMap(actualConf)).as("Conf pairs must match").isEqualTo(toMap(expectedConf)); + assertThat(actualConf.get("k1")).as("Conf values must be present").isEqualTo("v1"); + assertThat(actualConf.get("k2")).as("Conf values must be present").isEqualTo("v2"); + } + + @Test + public void testHadoopFileIOJavaSerialization() throws IOException, ClassNotFoundException { + FileIO io = table.io(); + Configuration expectedConf = ((HadoopFileIO) io).conf(); + + Table serializableTable = SerializableTableWithSize.copyOf(table); + FileIO deserializedIO = TestHelpers.roundTripSerialize(serializableTable.io()); + Configuration actualConf = ((HadoopFileIO) deserializedIO).conf(); + + assertThat(toMap(actualConf)).as("Conf pairs must 
match").isEqualTo(toMap(expectedConf)); + assertThat(actualConf.get("k1")).as("Conf values must be present").isEqualTo("v1"); + assertThat(actualConf.get("k2")).as("Conf values must be present").isEqualTo("v2"); + } + + private Map toMap(Configuration conf) { + Map map = Maps.newHashMapWithExpectedSize(conf.size()); + conf.forEach(entry -> map.put(entry.getKey(), entry.getValue())); + return map; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestHadoopMetricsContextSerialization.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestHadoopMetricsContextSerialization.java new file mode 100644 index 000000000000..a4643d7a087b --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestHadoopMetricsContextSerialization.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import java.io.IOException; +import org.apache.iceberg.hadoop.HadoopMetricsContext; +import org.apache.iceberg.io.FileIOMetricsContext; +import org.apache.iceberg.metrics.MetricsContext; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.junit.jupiter.api.Test; + +public class TestHadoopMetricsContextSerialization { + + @Test + public void testHadoopMetricsContextKryoSerialization() throws IOException { + MetricsContext metricsContext = new HadoopMetricsContext("s3"); + + metricsContext.initialize(Maps.newHashMap()); + + MetricsContext deserializedMetricContext = KryoHelpers.roundTripSerialize(metricsContext); + // statistics are properly re-initialized post de-serialization + deserializedMetricContext + .counter(FileIOMetricsContext.WRITE_BYTES, MetricsContext.Unit.BYTES) + .increment(); + } + + @Test + public void testHadoopMetricsContextJavaSerialization() + throws IOException, ClassNotFoundException { + MetricsContext metricsContext = new HadoopMetricsContext("s3"); + + metricsContext.initialize(Maps.newHashMap()); + + MetricsContext deserializedMetricContext = TestHelpers.roundTripSerialize(metricsContext); + // statistics are properly re-initialized post de-serialization + deserializedMetricContext + .counter(FileIOMetricsContext.WRITE_BYTES, MetricsContext.Unit.BYTES) + .increment(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestManifestFileSerialization.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestManifestFileSerialization.java new file mode 100644 index 000000000000..1e09917d0305 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestManifestFileSerialization.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Path; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.ManifestFile.PartitionFieldSummary; +import org.apache.iceberg.hadoop.HadoopFileIO; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.KryoSerializer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestManifestFileSerialization { + + private static final Schema SCHEMA = + new Schema( + required(1, "id", Types.LongType.get()), + optional(2, "data", Types.StringType.get()), + required(3, "date", Types.StringType.get()), + required(4, "double", Types.DoubleType.get())); + + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA).identity("double").build(); + + private static final DataFile FILE_A = + DataFiles.builder(SPEC) + .withPath("/path/to/data-1.parquet") + .withFileSizeInBytes(0) + .withPartition(TestHelpers.Row.of(1D)) + .withPartitionPath("double=1") + .withMetrics( + new Metrics( + 5L, + null, // no column sizes + ImmutableMap.of(1, 5L, 2, 3L), // value count + ImmutableMap.of(1, 0L, 2, 2L), // null count + ImmutableMap.of(), // nan count + ImmutableMap.of(1, longToBuffer(0L)), // lower bounds + ImmutableMap.of(1, longToBuffer(4L)) // upper bounds + )) + .build(); + + private static final DataFile FILE_B = + DataFiles.builder(SPEC) + .withPath("/path/to/data-2.parquet") + .withFileSizeInBytes(0) + .withPartition(TestHelpers.Row.of(Double.NaN)) + .withPartitionPath("double=NaN") + .withMetrics( + new Metrics( + 1L, + null, // no column sizes + ImmutableMap.of(1, 1L, 4, 1L), // value count + ImmutableMap.of(1, 0L, 2, 0L), // null count + ImmutableMap.of(4, 1L), // nan count + ImmutableMap.of(1, longToBuffer(0L)), // lower bounds + ImmutableMap.of(1, longToBuffer(1L)) // upper bounds + )) + .build(); + + private static final FileIO FILE_IO = new HadoopFileIO(new Configuration()); + + @TempDir private Path temp; + + @Test + public void 
testManifestFileKryoSerialization() throws IOException { + File data = File.createTempFile("junit", null, temp.toFile()); + assertThat(data.delete()).isTrue(); + + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + ManifestFile manifest = writeManifest(FILE_A, FILE_B); + + try (Output out = new Output(new FileOutputStream(data))) { + kryo.writeClassAndObject(out, manifest); + kryo.writeClassAndObject(out, manifest.copy()); + kryo.writeClassAndObject(out, GenericManifestFile.copyOf(manifest).build()); + } + + try (Input in = new Input(new FileInputStream(data))) { + for (int i = 0; i < 3; i += 1) { + Object obj = kryo.readClassAndObject(in); + assertThat(obj).as("Should be a ManifestFile").isInstanceOf(ManifestFile.class); + checkManifestFile(manifest, (ManifestFile) obj); + } + } + } + + @Test + public void testManifestFileJavaSerialization() throws Exception { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + + ManifestFile manifest = writeManifest(FILE_A, FILE_B); + + try (ObjectOutputStream out = new ObjectOutputStream(bytes)) { + out.writeObject(manifest); + out.writeObject(manifest.copy()); + out.writeObject(GenericManifestFile.copyOf(manifest).build()); + } + + try (ObjectInputStream in = + new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) { + for (int i = 0; i < 3; i += 1) { + Object obj = in.readObject(); + assertThat(obj).as("Should be a ManifestFile").isInstanceOf(ManifestFile.class); + checkManifestFile(manifest, (ManifestFile) obj); + } + } + } + + private void checkManifestFile(ManifestFile expected, ManifestFile actual) { + assertThat(actual.path()).as("Path must match").isEqualTo(expected.path()); + assertThat(actual.length()).as("Length must match").isEqualTo(expected.length()); + assertThat(actual.partitionSpecId()) + .as("Spec id must match") + .isEqualTo(expected.partitionSpecId()); + assertThat(actual.snapshotId()).as("Snapshot id must match").isEqualTo(expected.snapshotId()); + assertThat(actual.hasAddedFiles()) + .as("Added files flag must match") + .isEqualTo(expected.hasAddedFiles()); + assertThat(actual.addedFilesCount()) + .as("Added files count must match") + .isEqualTo(expected.addedFilesCount()); + assertThat(actual.addedRowsCount()) + .as("Added rows count must match") + .isEqualTo(expected.addedRowsCount()); + assertThat(actual.hasExistingFiles()) + .as("Existing files flag must match") + .isEqualTo(expected.hasExistingFiles()); + assertThat(actual.existingFilesCount()) + .as("Existing files count must match") + .isEqualTo(expected.existingFilesCount()); + assertThat(actual.existingRowsCount()) + .as("Existing rows count must match") + .isEqualTo(expected.existingRowsCount()); + assertThat(actual.hasDeletedFiles()) + .as("Deleted files flag must match") + .isEqualTo(expected.hasDeletedFiles()); + assertThat(actual.deletedFilesCount()) + .as("Deleted files count must match") + .isEqualTo(expected.deletedFilesCount()); + assertThat(actual.deletedRowsCount()) + .as("Deleted rows count must match") + .isEqualTo(expected.deletedRowsCount()); + + PartitionFieldSummary expectedPartition = expected.partitions().get(0); + PartitionFieldSummary actualPartition = actual.partitions().get(0); + + assertThat(actualPartition.containsNull()) + .as("Null flag in partition must match") + .isEqualTo(expectedPartition.containsNull()); + assertThat(actualPartition.containsNaN()) + .as("NaN flag in partition must match") + .isEqualTo(expectedPartition.containsNaN()); + assertThat(actualPartition.lowerBound()) + .as("Lower bounds in 
partition must match") + .isEqualTo(expectedPartition.lowerBound()); + assertThat(actualPartition.upperBound()) + .as("Upper bounds in partition must match") + .isEqualTo(expectedPartition.upperBound()); + } + + private ManifestFile writeManifest(DataFile... files) throws IOException { + File manifestFile = File.createTempFile("input.m0", ".avro", temp.toFile()); + assertThat(manifestFile.delete()).isTrue(); + OutputFile outputFile = FILE_IO.newOutputFile(manifestFile.getCanonicalPath()); + + ManifestWriter writer = ManifestFiles.write(SPEC, outputFile); + try { + for (DataFile file : files) { + writer.add(file); + } + } finally { + writer.close(); + } + + return writer.toManifestFile(); + } + + private static ByteBuffer longToBuffer(long value) { + return ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(0, value); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestScanTaskSerialization.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestScanTaskSerialization.java new file mode 100644 index 000000000000..4fdbc862ee8c --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestScanTaskSerialization.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.KryoSerializer; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestScanTaskSerialization extends TestBase { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + + @TempDir private Path temp; + @TempDir private File tableDir; + + private String tableLocation = null; + + @BeforeEach + public void setupTableLocation() throws Exception { + this.tableLocation = tableDir.toURI().toString(); + } + + @Test + public void testBaseCombinedScanTaskKryoSerialization() throws Exception { + BaseCombinedScanTask scanTask = prepareBaseCombinedScanTaskForSerDeTest(); + + File data = File.createTempFile("junit", null, temp.toFile()); + assertThat(data.delete()).isTrue(); + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + try (Output out = new Output(new FileOutputStream(data))) { + kryo.writeClassAndObject(out, scanTask); + } + + try (Input in = new Input(new FileInputStream(data))) { + Object obj = kryo.readClassAndObject(in); + assertThat(obj) + .as("Should be a BaseCombinedScanTask") + .isInstanceOf(BaseCombinedScanTask.class); + TaskCheckHelper.assertEquals(scanTask, (BaseCombinedScanTask) obj); + } + } + + @Test + public void testBaseCombinedScanTaskJavaSerialization() throws Exception { + BaseCombinedScanTask scanTask = prepareBaseCombinedScanTaskForSerDeTest(); + + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + try (ObjectOutputStream out = new ObjectOutputStream(bytes)) { + out.writeObject(scanTask); + } + + try (ObjectInputStream in = + new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) { + Object obj = in.readObject(); + assertThat(obj) + .as("Should be a BaseCombinedScanTask") + .isInstanceOf(BaseCombinedScanTask.class); + TaskCheckHelper.assertEquals(scanTask, (BaseCombinedScanTask) obj); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testBaseScanTaskGroupKryoSerialization() throws Exception { + BaseScanTaskGroup taskGroup = prepareBaseScanTaskGroupForSerDeTest(); + + 
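+    // guard against an empty plan: with no tasks, the round-trip comparison below would pass vacuously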
assertThat(taskGroup.tasks()).as("Task group can't be empty").isNotEmpty(); + + File data = File.createTempFile("junit", null, temp.toFile()); + assertThat(data.delete()).isTrue(); + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + try (Output out = new Output(Files.newOutputStream(data.toPath()))) { + kryo.writeClassAndObject(out, taskGroup); + } + + try (Input in = new Input(Files.newInputStream(data.toPath()))) { + Object obj = kryo.readClassAndObject(in); + assertThat(obj).as("should be a BaseScanTaskGroup").isInstanceOf(BaseScanTaskGroup.class); + TaskCheckHelper.assertEquals(taskGroup, (BaseScanTaskGroup) obj); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testBaseScanTaskGroupJavaSerialization() throws Exception { + BaseScanTaskGroup taskGroup = prepareBaseScanTaskGroupForSerDeTest(); + + assertThat(taskGroup.tasks()).as("Task group can't be empty").isNotEmpty(); + + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + try (ObjectOutputStream out = new ObjectOutputStream(bytes)) { + out.writeObject(taskGroup); + } + + try (ObjectInputStream in = + new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) { + Object obj = in.readObject(); + assertThat(obj).as("should be a BaseScanTaskGroup").isInstanceOf(BaseScanTaskGroup.class); + TaskCheckHelper.assertEquals(taskGroup, (BaseScanTaskGroup) obj); + } + } + + private BaseCombinedScanTask prepareBaseCombinedScanTaskForSerDeTest() { + Table table = initTable(); + CloseableIterable tasks = table.newScan().planFiles(); + return new BaseCombinedScanTask(Lists.newArrayList(tasks)); + } + + private BaseScanTaskGroup prepareBaseScanTaskGroupForSerDeTest() { + Table table = initTable(); + CloseableIterable tasks = table.newScan().planFiles(); + return new BaseScanTaskGroup<>(ImmutableList.copyOf(tasks)); + } + + private Table initTable() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB")); + writeRecords(records1); + + List records2 = + Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD")); + writeRecords(records2); + + table.refresh(); + + return table; + } + + private void writeRecords(List records) { + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + writeDF(df); + } + + private void writeDF(Dataset df) { + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanDeletes.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanDeletes.java new file mode 100644 index 000000000000..659507e4c5e3 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanDeletes.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; + +import java.util.Arrays; +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSparkDistributedDataScanDeletes + extends DeleteFileIndexTestBase> { + + @Parameters(name = "formatVersion = {0}, dataMode = {1}, deleteMode = {2}") + public static List parameters() { + return Arrays.asList( + new Object[] {2, LOCAL, LOCAL}, + new Object[] {2, LOCAL, DISTRIBUTED}, + new Object[] {2, DISTRIBUTED, LOCAL}, + new Object[] {2, LOCAL, DISTRIBUTED}, + new Object[] {3, LOCAL, LOCAL}, + new Object[] {3, LOCAL, DISTRIBUTED}, + new Object[] {3, DISTRIBUTED, LOCAL}, + new Object[] {3, DISTRIBUTED, DISTRIBUTED}); + } + + private static SparkSession spark = null; + + @Parameter(index = 1) + private PlanningMode dataMode; + + @Parameter(index = 2) + private PlanningMode deleteMode; + + @BeforeEach + public void configurePlanningModes() { + table + .updateProperties() + .set(TableProperties.DATA_PLANNING_MODE, dataMode.modeName()) + .set(TableProperties.DELETE_PLANNING_MODE, deleteMode.modeName()) + .commit(); + } + + @BeforeAll + public static void startSpark() { + TestSparkDistributedDataScanDeletes.spark = + SparkSession.builder() + .master("local[2]") + .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .config(SQLConf.SHUFFLE_PARTITIONS().key(), "4") + .getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestSparkDistributedDataScanDeletes.spark; + TestSparkDistributedDataScanDeletes.spark = null; + currentSpark.stop(); + } + + @Override + protected BatchScan newScan(Table table) { + SparkReadConf readConf = new SparkReadConf(spark, table, ImmutableMap.of()); + return new SparkDistributedDataScan(spark, table, readConf); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanFilterFiles.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanFilterFiles.java new file mode 100644 index 000000000000..a218f965ea65 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanFilterFiles.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; + +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSparkDistributedDataScanFilterFiles + extends FilterFilesTestBase> { + + @Parameters(name = "formatVersion = {0}, dataMode = {1}, deleteMode = {2}") + public static Object[] parameters() { + return new Object[][] { + new Object[] {1, LOCAL, LOCAL}, + new Object[] {1, LOCAL, DISTRIBUTED}, + new Object[] {1, DISTRIBUTED, LOCAL}, + new Object[] {1, DISTRIBUTED, DISTRIBUTED}, + new Object[] {2, LOCAL, LOCAL}, + new Object[] {2, LOCAL, DISTRIBUTED}, + new Object[] {2, DISTRIBUTED, LOCAL}, + new Object[] {2, DISTRIBUTED, DISTRIBUTED} + }; + } + + private static SparkSession spark = null; + + @Parameter(index = 1) + private PlanningMode dataMode; + + @Parameter(index = 2) + private PlanningMode deleteMode; + + @BeforeAll + public static void startSpark() { + TestSparkDistributedDataScanFilterFiles.spark = + SparkSession.builder() + .master("local[2]") + .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .config(SQLConf.SHUFFLE_PARTITIONS().key(), "4") + .getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestSparkDistributedDataScanFilterFiles.spark; + TestSparkDistributedDataScanFilterFiles.spark = null; + currentSpark.stop(); + } + + @Override + protected BatchScan newScan(Table table) { + table + .updateProperties() + .set(TableProperties.DATA_PLANNING_MODE, dataMode.modeName()) + .set(TableProperties.DELETE_PLANNING_MODE, deleteMode.modeName()) + .commit(); + SparkReadConf readConf = new SparkReadConf(spark, table, ImmutableMap.of()); + return new SparkDistributedDataScan(spark, table, readConf); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanJavaSerialization.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanJavaSerialization.java new file mode 100644 index 000000000000..b8bd6fb86747 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanJavaSerialization.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; + +public class TestSparkDistributedDataScanJavaSerialization + extends SparkDistributedDataScanTestBase { + + @BeforeAll + public static void startSpark() { + SparkDistributedDataScanTestBase.spark = + initSpark("org.apache.spark.serializer.JavaSerializer"); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = SparkDistributedDataScanTestBase.spark; + SparkDistributedDataScanTestBase.spark = null; + currentSpark.stop(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanKryoSerialization.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanKryoSerialization.java new file mode 100644 index 000000000000..08d66cccb627 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanKryoSerialization.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; + +public class TestSparkDistributedDataScanKryoSerialization + extends SparkDistributedDataScanTestBase { + + @BeforeAll + public static void startSpark() { + SparkDistributedDataScanTestBase.spark = + initSpark("org.apache.spark.serializer.KryoSerializer"); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = SparkDistributedDataScanTestBase.spark; + SparkDistributedDataScanTestBase.spark = null; + currentSpark.stop(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanReporting.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanReporting.java new file mode 100644 index 000000000000..2665d7ba8d3b --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestSparkDistributedDataScanReporting.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; + +import java.util.Arrays; +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSparkDistributedDataScanReporting + extends ScanPlanningAndReportingTestBase> { + + @Parameters(name = "formatVersion = {0}, dataMode = {1}, deleteMode = {2}") + public static List parameters() { + return Arrays.asList( + new Object[] {2, LOCAL, LOCAL}, + new Object[] {2, LOCAL, DISTRIBUTED}, + new Object[] {2, DISTRIBUTED, LOCAL}, + new Object[] {2, DISTRIBUTED, DISTRIBUTED}, + new Object[] {3, LOCAL, LOCAL}, + new Object[] {3, LOCAL, DISTRIBUTED}, + new Object[] {3, DISTRIBUTED, LOCAL}, + new Object[] {3, DISTRIBUTED, DISTRIBUTED}); + } + + private static SparkSession spark = null; + + @Parameter(index = 1) + private PlanningMode dataMode; + + @Parameter(index = 2) + private PlanningMode deleteMode; + + @BeforeAll + public static void startSpark() { + TestSparkDistributedDataScanReporting.spark = + SparkSession.builder() + .master("local[2]") + .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .config(SQLConf.SHUFFLE_PARTITIONS().key(), "4") + .getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestSparkDistributedDataScanReporting.spark; + TestSparkDistributedDataScanReporting.spark = null; + currentSpark.stop(); + } + + @Override + protected BatchScan newScan(Table table) { + table + .updateProperties() + .set(TableProperties.DATA_PLANNING_MODE, dataMode.modeName()) + .set(TableProperties.DELETE_PLANNING_MODE, deleteMode.modeName()) + .commit(); + SparkReadConf readConf = new SparkReadConf(spark, table, ImmutableMap.of()); + return new SparkDistributedDataScan(spark, table, readConf); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestTableSerialization.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestTableSerialization.java new file mode 100644 index 000000000000..fd6dfd07b568 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/TestTableSerialization.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.source.SerializableTableWithSize; +import org.apache.iceberg.types.Types; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestTableSerialization { + + @Parameters(name = "isObjectStoreEnabled = {0}") + public static List parameters() { + return Arrays.asList("true", "false"); + } + + private static final HadoopTables TABLES = new HadoopTables(); + + @Parameter private String isObjectStoreEnabled; + + private static final Schema SCHEMA = + new Schema( + required(1, "id", Types.LongType.get()), + optional(2, "data", Types.StringType.get()), + required(3, "date", Types.StringType.get()), + optional(4, "double", Types.DoubleType.get())); + + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA).identity("date").build(); + + private static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA).asc("id").build(); + + @TempDir private Path temp; + private Table table; + + @BeforeEach + public void initTable() throws IOException { + Map props = + ImmutableMap.of("k1", "v1", TableProperties.OBJECT_STORE_ENABLED, isObjectStoreEnabled); + + File tableLocation = Files.createTempDirectory(temp, "junit").toFile(); + assertThat(tableLocation.delete()).isTrue(); + + this.table = TABLES.create(SCHEMA, SPEC, SORT_ORDER, props, tableLocation.toString()); + } + + @TestTemplate + public void testCloseSerializableTableKryoSerialization() throws Exception { + for (Table tbl : tables()) { + Table spyTable = spy(tbl); + FileIO spyIO = spy(tbl.io()); + when(spyTable.io()).thenReturn(spyIO); + + Table serializableTable = SerializableTableWithSize.copyOf(spyTable); + + Table serializableTableCopy = spy(KryoHelpers.roundTripSerialize(serializableTable)); + FileIO spyFileIOCopy = spy(serializableTableCopy.io()); + when(serializableTableCopy.io()).thenReturn(spyFileIOCopy); + + ((AutoCloseable) serializableTable).close(); // mimics close on the driver + ((AutoCloseable) serializableTableCopy).close(); // mimics close on executors + + verify(spyIO, never()).close(); + verify(spyFileIOCopy, times(1)).close(); + } + } 
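+  // Same close-propagation check as the Kryo variant above, but the table copy is produced via
+  // Java serialization: closing the copy must close its FileIO while the original table's FileIO stays open.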
+ + @TestTemplate + public void testCloseSerializableTableJavaSerialization() throws Exception { + for (Table tbl : tables()) { + Table spyTable = spy(tbl); + FileIO spyIO = spy(tbl.io()); + when(spyTable.io()).thenReturn(spyIO); + + Table serializableTable = SerializableTableWithSize.copyOf(spyTable); + + Table serializableTableCopy = spy(TestHelpers.roundTripSerialize(serializableTable)); + FileIO spyFileIOCopy = spy(serializableTableCopy.io()); + when(serializableTableCopy.io()).thenReturn(spyFileIOCopy); + + ((AutoCloseable) serializableTable).close(); // mimics close on the driver + ((AutoCloseable) serializableTableCopy).close(); // mimics close on executors + + verify(spyIO, never()).close(); + verify(spyFileIOCopy, times(1)).close(); + } + } + + @TestTemplate + public void testSerializableTableKryoSerialization() throws IOException { + Table serializableTable = SerializableTableWithSize.copyOf(table); + TestHelpers.assertSerializedAndLoadedMetadata( + table, KryoHelpers.roundTripSerialize(serializableTable)); + } + + @TestTemplate + public void testSerializableMetadataTableKryoSerialization() throws IOException { + for (MetadataTableType type : MetadataTableType.values()) { + TableOperations ops = ((HasTableOperations) table).operations(); + Table metadataTable = + MetadataTableUtils.createMetadataTableInstance(ops, table.name(), "meta", type); + Table serializableMetadataTable = SerializableTableWithSize.copyOf(metadataTable); + + TestHelpers.assertSerializedAndLoadedMetadata( + metadataTable, KryoHelpers.roundTripSerialize(serializableMetadataTable)); + } + } + + @TestTemplate + public void testSerializableTransactionTableKryoSerialization() throws IOException { + Transaction txn = table.newTransaction(); + + txn.updateProperties().set("k1", "v1").commit(); + + Table txnTable = txn.table(); + Table serializableTxnTable = SerializableTableWithSize.copyOf(txnTable); + + TestHelpers.assertSerializedMetadata( + txnTable, KryoHelpers.roundTripSerialize(serializableTxnTable)); + } + + private List

<Table> tables() { + List<Table>
tables = Lists.newArrayList(); + + tables.add(table); + + for (MetadataTableType type : MetadataTableType.values()) { + Table metadataTable = MetadataTableUtils.createMetadataTableInstance(table, type); + tables.add(metadataTable); + } + + return tables; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/ValidationHelpers.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/ValidationHelpers.java new file mode 100644 index 000000000000..7314043b15e2 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/ValidationHelpers.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; + +public class ValidationHelpers { + + private ValidationHelpers() {} + + public static List dataSeqs(Long... seqs) { + return Arrays.asList(seqs); + } + + public static List fileSeqs(Long... seqs) { + return Arrays.asList(seqs); + } + + public static List snapshotIds(Long... ids) { + return Arrays.asList(ids); + } + + public static List files(ContentFile... 
files) { + return Arrays.stream(files).map(file -> file.location()).collect(Collectors.toList()); + } + + public static void validateDataManifest( + Table table, + ManifestFile manifest, + List dataSeqs, + List fileSeqs, + List snapshotIds, + List files) { + + List actualDataSeqs = Lists.newArrayList(); + List actualFileSeqs = Lists.newArrayList(); + List actualSnapshotIds = Lists.newArrayList(); + List actualFiles = Lists.newArrayList(); + + for (ManifestEntry entry : ManifestFiles.read(manifest, table.io()).entries()) { + actualDataSeqs.add(entry.dataSequenceNumber()); + actualFileSeqs.add(entry.fileSequenceNumber()); + actualSnapshotIds.add(entry.snapshotId()); + actualFiles.add(entry.file().location()); + } + + assertSameElements("data seqs", actualDataSeqs, dataSeqs); + assertSameElements("file seqs", actualFileSeqs, fileSeqs); + assertSameElements("snapshot IDs", actualSnapshotIds, snapshotIds); + assertSameElements("files", actualFiles, files); + } + + private static void assertSameElements(String context, List actual, List expected) { + String errorMessage = String.format("%s must match", context); + assertThat(actual).as(errorMessage).hasSameElementsAs(expected); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/CatalogTestBase.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/CatalogTestBase.java new file mode 100644 index 000000000000..ba864bf89e33 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/CatalogTestBase.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public abstract class CatalogTestBase extends TestBaseWithCatalog { + + // these parameters are broken out to avoid changes that need to modify lots of test suites + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + protected static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties() + }, + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties() + }, + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties() + } + }; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/Employee.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/Employee.java new file mode 100644 index 000000000000..9c57936d989e --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/Employee.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Objects; + +public class Employee { + private Integer id; + private String dep; + + public Employee() {} + + public Employee(Integer id, String dep) { + this.id = id; + this.dep = dep; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + public String getDep() { + return dep; + } + + public void setDep(String dep) { + this.dep = dep; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other == null || getClass() != other.getClass()) { + return false; + } + + Employee employee = (Employee) other; + return Objects.equals(id, employee.id) && Objects.equals(dep, employee.dep); + } + + @Override + public int hashCode() { + return Objects.hash(id, dep); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogConfig.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogConfig.java new file mode 100644 index 000000000000..abfd7da0c7bd --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogConfig.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.inmemory.InMemoryCatalog; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; + +public enum SparkCatalogConfig { + HIVE( + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default")), + HADOOP( + "testhadoop", + SparkCatalog.class.getName(), + ImmutableMap.of("type", "hadoop", "cache-enabled", "false")), + SPARK( + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "parquet-enabled", "true", + "cache-enabled", + "false" // Spark will delete tables using v1, leaving the cache out of sync + )), + SPARK_WITH_VIEWS( + "spark_with_views", + SparkCatalog.class.getName(), + ImmutableMap.of( + CatalogProperties.CATALOG_IMPL, + InMemoryCatalog.class.getName(), + "default-namespace", + "default", + "cache-enabled", + "false")); + + private final String catalogName; + private final String implementation; + private final Map properties; + + SparkCatalogConfig(String catalogName, String implementation, Map properties) { + this.catalogName = catalogName; + this.implementation = implementation; + this.properties = properties; + } + + public String catalogName() { + return catalogName; + } + + public String implementation() { + return implementation; + } + + public Map properties() { + return properties; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/SparkTestHelperBase.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/SparkTestHelperBase.java new file mode 100644 index 000000000000..0b3d0244a087 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/SparkTestHelperBase.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.spark.sql.Row; + +public class SparkTestHelperBase { + protected static final Object ANY = new Object(); + + protected List<Object[]> rowsToJava(List<Row> rows) { + return rows.stream().map(this::toJava).collect(Collectors.toList()); + } + + private Object[] toJava(Row row) { + return IntStream.range(0, row.size()) + .mapToObj( + pos -> { + if (row.isNullAt(pos)) { + return null; + } + + Object value = row.get(pos); + if (value instanceof Row) { + return toJava((Row) value); + } else if (value instanceof scala.collection.Seq) { + return row.getList(pos); + } else if (value instanceof scala.collection.Map) { + return row.getJavaMap(pos); + } else { + return value; + } + }) + .toArray(Object[]::new); + } + + protected void assertEquals( + String context, List<Object[]> expectedRows, List<Object[]> actualRows) { + assertThat(actualRows) + .as(context + ": number of results should match") + .hasSameSizeAs(expectedRows); + for (int row = 0; row < expectedRows.size(); row += 1) { + Object[] expected = expectedRows.get(row); + Object[] actual = actualRows.get(row); + assertThat(actual).as("Number of columns should match").hasSameSizeAs(expected); + for (int col = 0; col < actualRows.get(row).length; col += 1) { + String newContext = String.format("%s: row %d col %d", context, row + 1, col + 1); + assertEquals(newContext, expected, actual); + } + } + } + + protected void assertEquals(String context, Object[] expectedRow, Object[] actualRow) { + assertThat(actualRow).as("Number of columns should match").hasSameSizeAs(expectedRow); + for (int col = 0; col < actualRow.length; col += 1) { + Object expectedValue = expectedRow[col]; + Object actualValue = actualRow[col]; + if (expectedValue != null && expectedValue.getClass().isArray()) { + String newContext = String.format("%s (nested col %d)", context, col + 1); + if (expectedValue instanceof byte[]) { + assertThat(actualValue).as(newContext).isEqualTo(expectedValue); + } else { + assertEquals(newContext, (Object[]) expectedValue, (Object[]) actualValue); + } + } else if (expectedValue != ANY) { + assertThat(actualValue).as(context + " contents should match").isEqualTo(expectedValue); + } + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/SystemFunctionPushDownHelper.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/SystemFunctionPushDownHelper.java new file mode 100644 index 000000000000..059325e02a34 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/SystemFunctionPushDownHelper.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.SparkSession; + +public class SystemFunctionPushDownHelper { + public static final Types.StructType STRUCT = + Types.StructType.of( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "ts", Types.TimestampType.withZone()), + Types.NestedField.optional(3, "data", Types.StringType.get())); + + private SystemFunctionPushDownHelper() {} + + public static void createUnpartitionedTable(SparkSession spark, String tableName) { + sql(spark, "CREATE TABLE %s (id BIGINT, ts TIMESTAMP, data STRING) USING iceberg", tableName); + insertRecords(spark, tableName); + } + + public static void createPartitionedTable( + SparkSession spark, String tableName, String partitionCol) { + sql( + spark, + "CREATE TABLE %s (id BIGINT, ts TIMESTAMP, data STRING) USING iceberg PARTITIONED BY (%s)", + tableName, + partitionCol); + insertRecords(spark, tableName); + } + + private static void insertRecords(SparkSession spark, String tableName) { + sql( + spark, + "ALTER TABLE %s SET TBLPROPERTIES('%s' %s)", + tableName, + "read.split.target-size", + "10"); + sql( + spark, + "INSERT INTO TABLE %s VALUES %s", + tableName, + "(0, CAST('2017-11-22T09:20:44.294658+00:00' AS TIMESTAMP), 'data-0')"); + sql( + spark, + "INSERT INTO TABLE %s VALUES %s", + tableName, + "(1, CAST('2017-11-22T07:15:34.582910+00:00' AS TIMESTAMP), 'data-1')"); + sql( + spark, + "INSERT INTO TABLE %s VALUES %s", + tableName, + "(2, CAST('2017-11-22T06:02:09.243857+00:00' AS TIMESTAMP), 'data-2')"); + sql( + spark, + "INSERT INTO TABLE %s VALUES %s", + tableName, + "(3, CAST('2017-11-22T03:10:11.134509+00:00' AS TIMESTAMP), 'data-3')"); + sql( + spark, + "INSERT INTO TABLE %s VALUES %s", + tableName, + "(4, CAST('2017-11-22T00:34:00.184671+00:00' AS TIMESTAMP), 'data-4')"); + sql( + spark, + "INSERT INTO TABLE %s VALUES %s", + tableName, + "(5, CAST('2018-12-21T22:20:08.935889+00:00' AS TIMESTAMP), 'material-5')"); + sql( + spark, + "INSERT INTO TABLE %s VALUES %s", + tableName, + "(6, CAST('2018-12-21T21:55:30.589712+00:00' AS TIMESTAMP), 'material-6')"); + sql( + spark, + "INSERT INTO TABLE %s VALUES %s", + tableName, + "(7, CAST('2018-12-21T17:31:14.532797+00:00' AS TIMESTAMP), 'material-7')"); + sql( + spark, + "INSERT INTO TABLE %s VALUES %s", + tableName, + "(8, CAST('2018-12-21T15:21:51.237521+00:00' AS TIMESTAMP), 'material-8')"); + sql( + spark, + "INSERT INTO TABLE %s VALUES %s", + tableName, + "(9, CAST('2018-12-21T15:02:15.230570+00:00' AS TIMESTAMP), 'material-9')"); + } + + public static int timestampStrToYearOrdinal(String timestamp) { + return DateTimeUtil.microsToYears(DateTimeUtil.isoTimestamptzToMicros(timestamp)); + } + + public static int timestampStrToMonthOrdinal(String timestamp) { + return DateTimeUtil.microsToMonths(DateTimeUtil.isoTimestamptzToMicros(timestamp)); + } + + public static int timestampStrToDayOrdinal(String timestamp) { + return DateTimeUtil.microsToDays(DateTimeUtil.isoTimestamptzToMicros(timestamp)); + } + + public static int timestampStrToHourOrdinal(String timestamp) { + return DateTimeUtil.microsToHours(DateTimeUtil.isoTimestamptzToMicros(timestamp)); + } + + private static void sql(SparkSession spark, String sqlFormat, Object... 
args) { + spark.sql(String.format(sqlFormat, args)); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestBase.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestBase.java new file mode 100644 index 000000000000..86afd2f890ae --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestBase.java @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.hive.HiveCatalog; +import org.apache.iceberg.hive.TestHiveMetastore; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.execution.QueryExecution; +import org.apache.spark.sql.execution.SparkPlan; +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.util.QueryExecutionListener; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; + +public abstract class TestBase extends SparkTestHelperBase { + + protected static TestHiveMetastore metastore = null; + protected static HiveConf hiveConf = null; + protected static SparkSession spark = null; + protected static JavaSparkContext sparkContext = null; + protected static HiveCatalog catalog = null; + + @BeforeAll + public static void startMetastoreAndSpark() { + TestBase.metastore = new TestHiveMetastore(); + metastore.start(); + TestBase.hiveConf = metastore.hiveConf(); + + TestBase.spark = + SparkSession.builder() + .master("local[2]") + 
.config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic") + .config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .config("spark.sql.legacy.respectNullabilityInTextDatasetConversion", "true") + .enableHiveSupport() + .getOrCreate(); + + TestBase.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + + TestBase.catalog = + (HiveCatalog) + CatalogUtil.loadCatalog( + HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf); + + try { + catalog.createNamespace(Namespace.of("default")); + } catch (AlreadyExistsException ignored) { + // the default namespace already exists. ignore the create error + } + } + + @AfterAll + public static void stopMetastoreAndSpark() throws Exception { + TestBase.catalog = null; + if (metastore != null) { + metastore.stop(); + TestBase.metastore = null; + } + if (spark != null) { + spark.stop(); + TestBase.spark = null; + TestBase.sparkContext = null; + } + } + + protected long waitUntilAfter(long timestampMillis) { + long current = System.currentTimeMillis(); + while (current <= timestampMillis) { + current = System.currentTimeMillis(); + } + return current; + } + + protected List<Object[]> sql(String query, Object... args) { + List<Row> rows = spark.sql(String.format(query, args)).collectAsList(); + if (rows.isEmpty()) { + return ImmutableList.of(); + } + + return rowsToJava(rows); + } + + protected Object scalarSql(String query, Object... args) { + List<Object[]> rows = sql(query, args); + assertThat(rows.size()).as("Scalar SQL should return one row").isEqualTo(1); + Object[] row = Iterables.getOnlyElement(rows); + assertThat(row.length).as("Scalar SQL should return one value").isEqualTo(1); + return row[0]; + } + + protected Object[] row(Object... values) { + return values; + } + + protected static String dbPath(String dbName) { + return metastore.getDatabasePath(dbName); + } + + protected void withUnavailableFiles(Iterable<? extends ContentFile<?>> files, Action action) { + Iterable<String> fileLocations = Iterables.transform(files, file -> file.location()); + withUnavailableLocations(fileLocations, action); + } + + private void move(String location, String newLocation) { + Path path = Paths.get(URI.create(location)); + Path tempPath = Paths.get(URI.create(newLocation)); + + try { + Files.move(path, tempPath); + } catch (IOException e) { + throw new UncheckedIOException("Failed to move: " + location, e); + } + } + + protected void withUnavailableLocations(Iterable<String> locations, Action action) { + for (String location : locations) { + move(location, location + "_temp"); + } + + try { + action.invoke(); + } finally { + for (String location : locations) { + move(location + "_temp", location); + } + } + } + + protected void withDefaultTimeZone(String zoneId, Action action) { + TimeZone currentZone = TimeZone.getDefault(); + try { + TimeZone.setDefault(TimeZone.getTimeZone(zoneId)); + action.invoke(); + } finally { + TimeZone.setDefault(currentZone); + } + } + + protected void withSQLConf(Map<String, String> conf, Action action) { + SQLConf sqlConf = SQLConf.get(); + + Map<String, String> currentConfValues = Maps.newHashMap(); + conf.keySet() + .forEach( + confKey -> { + if (sqlConf.contains(confKey)) { + String currentConfValue = sqlConf.getConfString(confKey); + currentConfValues.put(confKey, currentConfValue); + } + }); + + conf.forEach( + (confKey, confValue) -> { + if (SQLConf.isStaticConfigKey(confKey)) { + throw new RuntimeException("Cannot modify the value of a static config: " + confKey); + } + sqlConf.setConfString(confKey, confValue); + }); + + try { + action.invoke(); + } finally { 
+ conf.forEach( + (confKey, confValue) -> { + if (currentConfValues.containsKey(confKey)) { + sqlConf.setConfString(confKey, currentConfValues.get(confKey)); + } else { + sqlConf.unsetConf(confKey); + } + }); + } + } + + protected Dataset jsonToDF(String schema, String... records) { + Dataset jsonDF = spark.createDataset(ImmutableList.copyOf(records), Encoders.STRING()); + return spark.read().schema(schema).json(jsonDF); + } + + protected void append(String table, String... jsonRecords) { + try { + String schema = spark.table(table).schema().toDDL(); + Dataset df = jsonToDF(schema, jsonRecords); + df.coalesce(1).writeTo(table).append(); + } catch (NoSuchTableException e) { + throw new RuntimeException("Failed to write data", e); + } + } + + protected String tablePropsAsString(Map tableProps) { + StringBuilder stringBuilder = new StringBuilder(); + + for (Map.Entry property : tableProps.entrySet()) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(String.format("'%s' '%s'", property.getKey(), property.getValue())); + } + + return stringBuilder.toString(); + } + + protected SparkPlan executeAndKeepPlan(String query, Object... args) { + return executeAndKeepPlan(() -> sql(query, args)); + } + + protected SparkPlan executeAndKeepPlan(Action action) { + AtomicReference executedPlanRef = new AtomicReference<>(); + + QueryExecutionListener listener = + new QueryExecutionListener() { + @Override + public void onSuccess(String funcName, QueryExecution qe, long durationNs) { + executedPlanRef.set(qe.executedPlan()); + } + + @Override + public void onFailure(String funcName, QueryExecution qe, Exception exception) {} + }; + + spark.listenerManager().register(listener); + + action.invoke(); + + try { + spark.sparkContext().listenerBus().waitUntilEmpty(); + } catch (TimeoutException e) { + throw new RuntimeException("Timeout while waiting for processing events", e); + } + + SparkPlan executedPlan = executedPlanRef.get(); + if (executedPlan instanceof AdaptiveSparkPlanExec) { + return ((AdaptiveSparkPlanExec) executedPlan).executedPlan(); + } else { + return executedPlan; + } + } + + @FunctionalInterface + protected interface Action { + void invoke(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestBaseWithCatalog.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestBaseWithCatalog.java new file mode 100644 index 000000000000..c869c4a30a19 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestBaseWithCatalog.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.util.Map; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PlanningMode; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.SupportsNamespaces; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.hadoop.HadoopCatalog; +import org.apache.iceberg.util.PropertyUtil; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public abstract class TestBaseWithCatalog extends TestBase { + protected static File warehouse = null; + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + protected static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties() + }, + }; + } + + @BeforeAll + public static void createWarehouse() throws IOException { + TestBaseWithCatalog.warehouse = File.createTempFile("warehouse", null); + assertThat(warehouse.delete()).isTrue(); + } + + @AfterAll + public static void dropWarehouse() throws IOException { + if (warehouse != null && warehouse.exists()) { + Path warehousePath = new Path(warehouse.getAbsolutePath()); + FileSystem fs = warehousePath.getFileSystem(hiveConf); + assertThat(fs.delete(warehousePath, true)).as("Failed to delete " + warehousePath).isTrue(); + } + } + + @TempDir protected java.nio.file.Path temp; + + @Parameter(index = 0) + protected String catalogName; + + @Parameter(index = 1) + protected String implementation; + + @Parameter(index = 2) + protected Map catalogConfig; + + protected Catalog validationCatalog; + protected SupportsNamespaces validationNamespaceCatalog; + protected TableIdentifier tableIdent = TableIdentifier.of(Namespace.of("default"), "table"); + protected String tableName; + + @BeforeEach + public void before() { + this.validationCatalog = + catalogName.equals("testhadoop") + ? new HadoopCatalog(spark.sessionState().newHadoopConf(), "file:" + warehouse) + : catalog; + this.validationNamespaceCatalog = (SupportsNamespaces) validationCatalog; + + spark.conf().set("spark.sql.catalog." + catalogName, implementation); + catalogConfig.forEach( + (key, value) -> spark.conf().set("spark.sql.catalog." + catalogName + "." + key, value)); + + if ("hadoop".equalsIgnoreCase(catalogConfig.get("type"))) { + spark.conf().set("spark.sql.catalog." + catalogName + ".warehouse", "file:" + warehouse); + } + + this.tableName = + (catalogName.equals("spark_catalog") ? "" : catalogName + ".") + "default.table"; + + sql("CREATE NAMESPACE IF NOT EXISTS default"); + } + + protected String tableName(String name) { + return (catalogName.equals("spark_catalog") ? "" : catalogName + ".") + "default." 
+ name; + } + + protected String commitTarget() { + return tableName; + } + + protected String selectTarget() { + return tableName; + } + + protected boolean cachingCatalogEnabled() { + return PropertyUtil.propertyAsBoolean( + catalogConfig, CatalogProperties.CACHE_ENABLED, CatalogProperties.CACHE_ENABLED_DEFAULT); + } + + protected void configurePlanningMode(PlanningMode planningMode) { + configurePlanningMode(tableName, planningMode); + } + + protected void configurePlanningMode(String table, PlanningMode planningMode) { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')", + table, + TableProperties.DATA_PLANNING_MODE, + planningMode.modeName(), + TableProperties.DELETE_PLANNING_MODE, + planningMode.modeName()); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestChangelogIterator.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestChangelogIterator.java new file mode 100644 index 000000000000..bd9832f7d674 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestChangelogIterator.java @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import org.apache.iceberg.ChangelogOperation; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.Test; + +public class TestChangelogIterator extends SparkTestHelperBase { + private static final String DELETE = ChangelogOperation.DELETE.name(); + private static final String INSERT = ChangelogOperation.INSERT.name(); + private static final String UPDATE_BEFORE = ChangelogOperation.UPDATE_BEFORE.name(); + private static final String UPDATE_AFTER = ChangelogOperation.UPDATE_AFTER.name(); + + private static final StructType SCHEMA = + new StructType( + new StructField[] { + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("name", DataTypes.StringType, false, Metadata.empty()), + new StructField("data", DataTypes.StringType, true, Metadata.empty()), + new StructField( + MetadataColumns.CHANGE_TYPE.name(), DataTypes.StringType, false, Metadata.empty()), + new StructField( + MetadataColumns.CHANGE_ORDINAL.name(), + DataTypes.IntegerType, + false, + Metadata.empty()), + new StructField( + MetadataColumns.COMMIT_SNAPSHOT_ID.name(), + DataTypes.LongType, + false, + Metadata.empty()) + }); + private static final String[] IDENTIFIER_FIELDS = new String[] {"id", "name"}; + + private enum RowType { + DELETED, + INSERTED, + CARRY_OVER, + UPDATED + } + + @Test + public void testIterator() { + List permutations = Lists.newArrayList(); + // generate 24 permutations + permute( + Arrays.asList(RowType.DELETED, RowType.INSERTED, RowType.CARRY_OVER, RowType.UPDATED), + 0, + permutations); + assertThat(permutations).hasSize(24); + + for (Object[] permutation : permutations) { + validate(permutation); + } + } + + private void validate(Object[] permutation) { + List rows = Lists.newArrayList(); + List expectedRows = Lists.newArrayList(); + for (int i = 0; i < permutation.length; i++) { + rows.addAll(toOriginalRows((RowType) permutation[i], i)); + expectedRows.addAll(toExpectedRows((RowType) permutation[i], i)); + } + + Iterator iterator = + ChangelogIterator.computeUpdates(rows.iterator(), SCHEMA, IDENTIFIER_FIELDS); + List result = Lists.newArrayList(iterator); + assertEquals("Rows should match", expectedRows, rowsToJava(result)); + } + + private List toOriginalRows(RowType rowType, int index) { + switch (rowType) { + case DELETED: + return Lists.newArrayList( + new GenericRowWithSchema(new Object[] {index, "b", "data", DELETE, 0, 0}, null)); + case INSERTED: + return Lists.newArrayList( + new GenericRowWithSchema(new Object[] {index, "c", "data", INSERT, 0, 0}, null)); + case CARRY_OVER: + return Lists.newArrayList( + new GenericRowWithSchema(new Object[] {index, "d", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {index, "d", "data", INSERT, 0, 0}, null)); + case UPDATED: + return Lists.newArrayList( + new GenericRowWithSchema(new Object[] {index, "a", "data", DELETE, 0, 0}, null), + new 
GenericRowWithSchema(new Object[] {index, "a", "new_data", INSERT, 0, 0}, null)); + default: + throw new IllegalArgumentException("Unknown row type: " + rowType); + } + } + + private List toExpectedRows(RowType rowType, int order) { + switch (rowType) { + case DELETED: + List rows = Lists.newArrayList(); + rows.add(new Object[] {order, "b", "data", DELETE, 0, 0}); + return rows; + case INSERTED: + List insertedRows = Lists.newArrayList(); + insertedRows.add(new Object[] {order, "c", "data", INSERT, 0, 0}); + return insertedRows; + case CARRY_OVER: + return Lists.newArrayList(); + case UPDATED: + return Lists.newArrayList( + new Object[] {order, "a", "data", UPDATE_BEFORE, 0, 0}, + new Object[] {order, "a", "new_data", UPDATE_AFTER, 0, 0}); + default: + throw new IllegalArgumentException("Unknown row type: " + rowType); + } + } + + private void permute(List arr, int start, List pm) { + for (int i = start; i < arr.size(); i++) { + Collections.swap(arr, i, start); + permute(arr, start + 1, pm); + Collections.swap(arr, start, i); + } + if (start == arr.size() - 1) { + pm.add(arr.toArray()); + } + } + + @Test + public void testRowsWithNullValue() { + final List rowsWithNull = + Lists.newArrayList( + new GenericRowWithSchema(new Object[] {2, null, null, DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {3, null, null, INSERT, 0, 0}, null), + new GenericRowWithSchema(new Object[] {4, null, null, DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {4, null, null, INSERT, 0, 0}, null), + // mixed null and non-null value in non-identifier columns + new GenericRowWithSchema(new Object[] {5, null, null, DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {5, null, "data", INSERT, 0, 0}, null), + // mixed null and non-null value in identifier columns + new GenericRowWithSchema(new Object[] {6, null, null, DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {6, "name", null, INSERT, 0, 0}, null)); + + Iterator iterator = + ChangelogIterator.computeUpdates(rowsWithNull.iterator(), SCHEMA, IDENTIFIER_FIELDS); + List result = Lists.newArrayList(iterator); + + assertEquals( + "Rows should match", + Lists.newArrayList( + new Object[] {2, null, null, DELETE, 0, 0}, + new Object[] {3, null, null, INSERT, 0, 0}, + new Object[] {5, null, null, UPDATE_BEFORE, 0, 0}, + new Object[] {5, null, "data", UPDATE_AFTER, 0, 0}, + new Object[] {6, null, null, DELETE, 0, 0}, + new Object[] {6, "name", null, INSERT, 0, 0}), + rowsToJava(result)); + } + + @Test + public void testUpdatedRowsWithDuplication() { + List rowsWithDuplication = + Lists.newArrayList( + // two rows with same identifier fields(id, name) + new GenericRowWithSchema(new Object[] {1, "a", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "a", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "a", "new_data", INSERT, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "a", "new_data", INSERT, 0, 0}, null)); + + Iterator iterator = + ChangelogIterator.computeUpdates(rowsWithDuplication.iterator(), SCHEMA, IDENTIFIER_FIELDS); + + assertThatThrownBy(() -> Lists.newArrayList(iterator)) + .isInstanceOf(IllegalStateException.class) + .hasMessage( + "Cannot compute updates because there are multiple rows with the same identifier fields([id,name]). 
Please make sure the rows are unique."); + + // still allow extra insert rows + rowsWithDuplication = + Lists.newArrayList( + new GenericRowWithSchema(new Object[] {1, "a", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "a", "new_data1", INSERT, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "a", "new_data2", INSERT, 0, 0}, null)); + + Iterator iterator1 = + ChangelogIterator.computeUpdates(rowsWithDuplication.iterator(), SCHEMA, IDENTIFIER_FIELDS); + + assertEquals( + "Rows should match.", + Lists.newArrayList( + new Object[] {1, "a", "data", UPDATE_BEFORE, 0, 0}, + new Object[] {1, "a", "new_data1", UPDATE_AFTER, 0, 0}, + new Object[] {1, "a", "new_data2", INSERT, 0, 0}), + rowsToJava(Lists.newArrayList(iterator1))); + } + + @Test + public void testCarryRowsRemoveWithDuplicates() { + // assume rows are sorted by id and change type + List rowsWithDuplication = + Lists.newArrayList( + // keep all delete rows for id 0 and id 1 since there is no insert row for them + new GenericRowWithSchema(new Object[] {0, "a", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {0, "a", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {0, "a", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "a", "old_data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "a", "old_data", DELETE, 0, 0}, null), + // the same number of delete and insert rows for id 2 + new GenericRowWithSchema(new Object[] {2, "a", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {2, "a", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {2, "a", "data", INSERT, 0, 0}, null), + new GenericRowWithSchema(new Object[] {2, "a", "data", INSERT, 0, 0}, null), + new GenericRowWithSchema(new Object[] {3, "a", "new_data", INSERT, 0, 0}, null)); + + List expectedRows = + Lists.newArrayList( + new Object[] {0, "a", "data", DELETE, 0, 0}, + new Object[] {0, "a", "data", DELETE, 0, 0}, + new Object[] {0, "a", "data", DELETE, 0, 0}, + new Object[] {1, "a", "old_data", DELETE, 0, 0}, + new Object[] {1, "a", "old_data", DELETE, 0, 0}, + new Object[] {3, "a", "new_data", INSERT, 0, 0}); + + validateIterators(rowsWithDuplication, expectedRows); + } + + @Test + public void testCarryRowsRemoveLessInsertRows() { + // less insert rows than delete rows + List rowsWithDuplication = + Lists.newArrayList( + new GenericRowWithSchema(new Object[] {1, "d", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", INSERT, 0, 0}, null), + new GenericRowWithSchema(new Object[] {2, "d", "data", INSERT, 0, 0}, null)); + + List expectedRows = + Lists.newArrayList( + new Object[] {1, "d", "data", DELETE, 0, 0}, + new Object[] {2, "d", "data", INSERT, 0, 0}); + + validateIterators(rowsWithDuplication, expectedRows); + } + + @Test + public void testCarryRowsRemoveMoreInsertRows() { + List rowsWithDuplication = + Lists.newArrayList( + new GenericRowWithSchema(new Object[] {0, "d", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", DELETE, 0, 0}, null), + // more insert rows than delete rows, should keep extra insert rows + new GenericRowWithSchema(new Object[] {1, "d", "data", INSERT, 0, 0}, null), + new 
GenericRowWithSchema(new Object[] {1, "d", "data", INSERT, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", INSERT, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", INSERT, 0, 0}, null)); + + List expectedRows = + Lists.newArrayList( + new Object[] {0, "d", "data", DELETE, 0, 0}, + new Object[] {1, "d", "data", INSERT, 0, 0}); + + validateIterators(rowsWithDuplication, expectedRows); + } + + @Test + public void testCarryRowsRemoveNoInsertRows() { + // no insert row + List rowsWithDuplication = + Lists.newArrayList( + // next two rows are identical + new GenericRowWithSchema(new Object[] {1, "d", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", DELETE, 0, 0}, null)); + + List expectedRows = + Lists.newArrayList( + new Object[] {1, "d", "data", DELETE, 0, 0}, + new Object[] {1, "d", "data", DELETE, 0, 0}); + + validateIterators(rowsWithDuplication, expectedRows); + } + + private void validateIterators(List rowsWithDuplication, List expectedRows) { + Iterator iterator = + ChangelogIterator.removeCarryovers(rowsWithDuplication.iterator(), SCHEMA); + List result = Lists.newArrayList(iterator); + + assertEquals("Rows should match.", expectedRows, rowsToJava(result)); + + iterator = ChangelogIterator.removeNetCarryovers(rowsWithDuplication.iterator(), SCHEMA); + result = Lists.newArrayList(iterator); + + assertEquals("Rows should match.", expectedRows, rowsToJava(result)); + } + + @Test + public void testRemoveNetCarryovers() { + List rowsWithDuplication = + Lists.newArrayList( + // this row are different from other rows, it is a net change, should be kept + new GenericRowWithSchema(new Object[] {0, "d", "data", DELETE, 0, 0}, null), + // a pair of delete and insert rows, should be removed + new GenericRowWithSchema(new Object[] {1, "d", "data", DELETE, 0, 0}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", INSERT, 0, 0}, null), + // 2 delete rows and 2 insert rows, should be removed + new GenericRowWithSchema(new Object[] {1, "d", "data", DELETE, 1, 1}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", DELETE, 1, 1}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", INSERT, 1, 1}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", INSERT, 1, 1}, null), + // a pair of insert and delete rows across snapshots, should be removed + new GenericRowWithSchema(new Object[] {1, "d", "data", INSERT, 2, 2}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", DELETE, 3, 3}, null), + // extra insert rows, they are net changes, should be kept + new GenericRowWithSchema(new Object[] {1, "d", "data", INSERT, 4, 4}, null), + new GenericRowWithSchema(new Object[] {1, "d", "data", INSERT, 4, 4}, null), + // different key, net changes, should be kept + new GenericRowWithSchema(new Object[] {2, "d", "data", DELETE, 4, 4}, null)); + + List expectedRows = + Lists.newArrayList( + new Object[] {0, "d", "data", DELETE, 0, 0}, + new Object[] {1, "d", "data", INSERT, 4, 4}, + new Object[] {1, "d", "data", INSERT, 4, 4}, + new Object[] {2, "d", "data", DELETE, 4, 4}); + + Iterator iterator = + ChangelogIterator.removeNetCarryovers(rowsWithDuplication.iterator(), SCHEMA); + List result = Lists.newArrayList(iterator); + + assertEquals("Rows should match.", expectedRows, rowsToJava(result)); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestFileRewriteCoordinator.java 
b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestFileRewriteCoordinator.java new file mode 100644 index 000000000000..666634a06c02 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestFileRewriteCoordinator.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.iceberg.util.DataFileSet; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestFileRewriteCoordinator extends CatalogTestBase { + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testBinPackRewrite() throws NoSuchTableException, IOException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + + Dataset df = newDF(1000); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should produce 4 snapshots").hasSize(4); + + Dataset fileDF = + spark.read().format("iceberg").load(tableName(tableIdent.name() + ".files")); + List fileSizes = fileDF.select("file_size_in_bytes").as(Encoders.LONG()).collectAsList(); + long avgFileSize = fileSizes.stream().mapToLong(i -> i).sum() / fileSizes.size(); + + try (CloseableIterable fileScanTasks = table.newScan().planFiles()) { + String fileSetID = UUID.randomUUID().toString(); + + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + taskSetManager.stageTasks(table, fileSetID, Lists.newArrayList(fileScanTasks)); + + // read and pack original 4 files into 2 splits + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.SPLIT_SIZE, 
Long.toString(avgFileSize * 2)) + .option(SparkReadOptions.FILE_OPEN_COST, "0") + .load(tableName); + + // write the packed data into new files where each split becomes a new file + scanDF + .writeTo(tableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + // commit the rewrite + FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get(); + Set rewrittenFiles = + taskSetManager.fetchTasks(table, fileSetID).stream() + .map(t -> t.asFileScanTask().file()) + .collect(Collectors.toCollection(DataFileSet::create)); + Set addedFiles = rewriteCoordinator.fetchNewFiles(table, fileSetID); + table.newRewrite().rewriteFiles(rewrittenFiles, addedFiles).commit(); + } + + table.refresh(); + + Map summary = table.currentSnapshot().summary(); + assertThat(summary.get("deleted-data-files")) + .as("Deleted files count must match") + .isEqualTo("4"); + assertThat(summary.get("added-data-files")).as("Added files count must match").isEqualTo("2"); + + Object rowCount = scalarSql("SELECT count(*) FROM %s", tableName); + assertThat(rowCount).as("Row count must match").isEqualTo(4000L); + } + + @TestTemplate + public void testSortRewrite() throws NoSuchTableException, IOException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + + Dataset df = newDF(1000); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should produce 4 snapshots").hasSize(4); + + try (CloseableIterable fileScanTasks = table.newScan().planFiles()) { + String fileSetID = UUID.randomUUID().toString(); + + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + taskSetManager.stageTasks(table, fileSetID, Lists.newArrayList(fileScanTasks)); + + // read original 4 files as 4 splits + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.SPLIT_SIZE, "134217728") + .option(SparkReadOptions.FILE_OPEN_COST, "134217728") + .load(tableName); + + // make sure we disable AQE and set the number of shuffle partitions as the target num files + ImmutableMap sqlConf = + ImmutableMap.of( + "spark.sql.shuffle.partitions", "2", + "spark.sql.adaptive.enabled", "false"); + + withSQLConf( + sqlConf, + () -> { + try { + // write new files with sorted records + scanDF + .sort("id") + .writeTo(tableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + } catch (NoSuchTableException e) { + throw new RuntimeException("Could not replace files", e); + } + }); + + // commit the rewrite + FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get(); + Set rewrittenFiles = + taskSetManager.fetchTasks(table, fileSetID).stream() + .map(t -> t.asFileScanTask().file()) + .collect(Collectors.toCollection(DataFileSet::create)); + Set addedFiles = rewriteCoordinator.fetchNewFiles(table, fileSetID); + table.newRewrite().rewriteFiles(rewrittenFiles, addedFiles).commit(); + } + + table.refresh(); + + Map summary = table.currentSnapshot().summary(); + assertThat(summary.get("deleted-data-files")) + .as("Deleted files count must match") + .isEqualTo("4"); + assertThat(summary.get("added-data-files")).as("Added files count must match").isEqualTo("2"); + + Object rowCount = scalarSql("SELECT count(*) FROM %s", tableName); + 
assertThat(rowCount).as("Row count must match").isEqualTo(4000L); + } + + @TestTemplate + public void testCommitMultipleRewrites() throws NoSuchTableException, IOException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + + Dataset df = newDF(1000); + + // add first two files + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + + String firstFileSetID = UUID.randomUUID().toString(); + long firstFileSetSnapshotId = table.currentSnapshot().snapshotId(); + + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + + try (CloseableIterable tasks = table.newScan().planFiles()) { + // stage first 2 files for compaction + taskSetManager.stageTasks(table, firstFileSetID, Lists.newArrayList(tasks)); + } + + // add two more files + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + + table.refresh(); + + String secondFileSetID = UUID.randomUUID().toString(); + + try (CloseableIterable tasks = + table.newScan().appendsAfter(firstFileSetSnapshotId).planFiles()) { + // stage 2 more files for compaction + taskSetManager.stageTasks(table, secondFileSetID, Lists.newArrayList(tasks)); + } + + ImmutableSet fileSetIDs = ImmutableSet.of(firstFileSetID, secondFileSetID); + + for (String fileSetID : fileSetIDs) { + // read and pack 2 files into 1 split + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.SPLIT_SIZE, Long.MAX_VALUE) + .load(tableName); + + // write the combined data as one file + scanDF + .writeTo(tableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + } + + // commit both rewrites at the same time + FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get(); + Set rewrittenFiles = + fileSetIDs.stream() + .flatMap(fileSetID -> taskSetManager.fetchTasks(table, fileSetID).stream()) + .map(t -> t.asFileScanTask().file()) + .collect(Collectors.toSet()); + Set addedFiles = + fileSetIDs.stream() + .flatMap(fileSetID -> rewriteCoordinator.fetchNewFiles(table, fileSetID).stream()) + .collect(Collectors.toCollection(DataFileSet::create)); + table.newRewrite().rewriteFiles(rewrittenFiles, addedFiles).commit(); + + table.refresh(); + + assertThat(table.snapshots()).as("Should produce 5 snapshots").hasSize(5); + + Map summary = table.currentSnapshot().summary(); + assertThat(summary.get("deleted-data-files")) + .as("Deleted files count must match") + .isEqualTo("4"); + assertThat(summary.get("added-data-files")).as("Added files count must match").isEqualTo("2"); + + Object rowCount = scalarSql("SELECT count(*) FROM %s", tableName); + assertThat(rowCount).as("Row count must match").isEqualTo(4000L); + } + + private Dataset newDF(int numRecords) { + List data = Lists.newArrayListWithExpectedSize(numRecords); + for (int index = 0; index < numRecords; index++) { + data.add(new SimpleRecord(index, Integer.toString(index))); + } + return spark.createDataFrame(data, SimpleRecord.class); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestFunctionCatalog.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestFunctionCatalog.java new file mode 100644 index 000000000000..5f160bcf10f8 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestFunctionCatalog.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.apache.iceberg.IcebergBuild; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.functions.IcebergVersionFunction; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.connector.catalog.FunctionCatalog; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestFunctionCatalog extends TestBaseWithCatalog { + private static final String[] EMPTY_NAMESPACE = new String[] {}; + private static final String[] SYSTEM_NAMESPACE = new String[] {"system"}; + private static final String[] DEFAULT_NAMESPACE = new String[] {"default"}; + private static final String[] DB_NAMESPACE = new String[] {"db"}; + private FunctionCatalog asFunctionCatalog; + + @BeforeEach + public void before() { + super.before(); + this.asFunctionCatalog = castToFunctionCatalog(catalogName); + sql("CREATE NAMESPACE IF NOT EXISTS %s", catalogName + ".default"); + } + + @AfterEach + public void dropDefaultNamespace() { + sql("DROP NAMESPACE IF EXISTS %s", catalogName + ".default"); + } + + @TestTemplate + public void testListFunctionsViaCatalog() throws NoSuchNamespaceException { + assertThat(asFunctionCatalog.listFunctions(EMPTY_NAMESPACE)) + .anyMatch(func -> "iceberg_version".equals(func.name())); + + assertThat(asFunctionCatalog.listFunctions(SYSTEM_NAMESPACE)) + .anyMatch(func -> "iceberg_version".equals(func.name())); + + assertThat(asFunctionCatalog.listFunctions(DEFAULT_NAMESPACE)) + .as("Listing functions in an existing namespace that's not system should not throw") + .isEqualTo(new Identifier[0]); + + assertThatThrownBy(() -> asFunctionCatalog.listFunctions(DB_NAMESPACE)) + .isInstanceOf(NoSuchNamespaceException.class) + .hasMessageStartingWith("[SCHEMA_NOT_FOUND] The schema `db` cannot be found."); + } + + @TestTemplate + public void testLoadFunctions() throws NoSuchFunctionException { + for (String[] namespace : ImmutableList.of(EMPTY_NAMESPACE, SYSTEM_NAMESPACE)) { + Identifier identifier = Identifier.of(namespace, "iceberg_version"); + UnboundFunction func = asFunctionCatalog.loadFunction(identifier); + + assertThat(func) + .isNotNull() + .isInstanceOf(UnboundFunction.class) + .isExactlyInstanceOf(IcebergVersionFunction.class); + } + + assertThatThrownBy( + () 
-> + asFunctionCatalog.loadFunction(Identifier.of(DEFAULT_NAMESPACE, "iceberg_version"))) + .isInstanceOf(NoSuchFunctionException.class) + .hasMessageStartingWith( + String.format( + "[ROUTINE_NOT_FOUND] The routine default.iceberg_version cannot be found")); + + Identifier undefinedFunction = Identifier.of(SYSTEM_NAMESPACE, "undefined_function"); + assertThatThrownBy(() -> asFunctionCatalog.loadFunction(undefinedFunction)) + .isInstanceOf(NoSuchFunctionException.class) + .hasMessageStartingWith( + String.format( + "[ROUTINE_NOT_FOUND] The routine system.undefined_function cannot be found")); + + assertThatThrownBy(() -> sql("SELECT undefined_function(1, 2)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "[UNRESOLVED_ROUTINE] Cannot resolve routine `undefined_function` on search path"); + } + + @TestTemplate + public void testCallingFunctionInSQLEndToEnd() { + String buildVersion = IcebergBuild.version(); + + assertThat(scalarSql("SELECT %s.system.iceberg_version()", catalogName)) + .as( + "Should be able to use the Iceberg version function from the fully qualified system namespace") + .isEqualTo(buildVersion); + + assertThat(scalarSql("SELECT %s.iceberg_version()", catalogName)) + .as( + "Should be able to use the Iceberg version function when fully qualified without specifying a namespace") + .isEqualTo(buildVersion); + + sql("USE %s", catalogName); + + assertThat(scalarSql("SELECT system.iceberg_version()")) + .as( + "Should be able to call iceberg_version from system namespace without fully qualified name when using Iceberg catalog") + .isEqualTo(buildVersion); + + assertThat(scalarSql("SELECT iceberg_version()")) + .as( + "Should be able to call iceberg_version from empty namespace without fully qualified name when using Iceberg catalog") + .isEqualTo(buildVersion); + } + + private FunctionCatalog castToFunctionCatalog(String name) { + return (FunctionCatalog) spark.sessionState().catalogManager().catalog(name); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java new file mode 100644 index 000000000000..6f900ffebb10 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.NullOrder.NULLS_FIRST; +import static org.apache.iceberg.NullOrder.NULLS_LAST; +import static org.apache.iceberg.expressions.Expressions.and; +import static org.apache.iceberg.expressions.Expressions.bucket; +import static org.apache.iceberg.expressions.Expressions.day; +import static org.apache.iceberg.expressions.Expressions.equal; +import static org.apache.iceberg.expressions.Expressions.greaterThan; +import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.hour; +import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.lessThan; +import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.month; +import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.truncate; +import static org.apache.iceberg.expressions.Expressions.year; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.iceberg.CachingCatalog; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.SortOrderParser; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.types.Types; +import org.junit.jupiter.api.Test; + +public class TestSpark3Util extends TestBase { + @Test + public void testDescribeSortOrder() { + Schema schema = + new Schema( + required(1, "data", Types.StringType.get()), + required(2, "time", Types.TimestampType.withoutZone())); + + assertThat(Spark3Util.describe(buildSortOrder("Identity", schema, 1))) + .as("Sort order isn't correct.") + .isEqualTo("data DESC NULLS FIRST"); + + assertThat(Spark3Util.describe(buildSortOrder("bucket[1]", schema, 1))) + .as("Sort order isn't correct.") + .isEqualTo("bucket(1, data) DESC NULLS FIRST"); + + assertThat(Spark3Util.describe(buildSortOrder("truncate[3]", schema, 1))) + .as("Sort order isn't correct.") + .isEqualTo("truncate(data, 3) DESC NULLS FIRST"); + + assertThat(Spark3Util.describe(buildSortOrder("year", schema, 2))) + .as("Sort order isn't correct.") + .isEqualTo("years(time) DESC NULLS FIRST"); + + assertThat(Spark3Util.describe(buildSortOrder("month", schema, 2))) + .as("Sort order isn't correct.") + .isEqualTo("months(time) DESC NULLS FIRST"); + + assertThat(Spark3Util.describe(buildSortOrder("day", schema, 2))) + .as("Sort order isn't correct.") + .isEqualTo("days(time) DESC NULLS FIRST"); + + assertThat(Spark3Util.describe(buildSortOrder("hour", schema, 2))) + .as("Sort order isn't correct.") + .isEqualTo("hours(time) DESC NULLS FIRST"); + + assertThat(Spark3Util.describe(buildSortOrder("unknown", schema, 1))) + .as("Sort order isn't correct.") + .isEqualTo("unknown(data) DESC NULLS FIRST"); + + // multiple sort orders + SortOrder multiOrder = + SortOrder.builderFor(schema).asc("time", NULLS_FIRST).asc("data", NULLS_LAST).build(); + assertThat(Spark3Util.describe(multiOrder)) + .as("Sort order isn't correct.") + .isEqualTo("time ASC NULLS FIRST, data ASC NULLS LAST"); + } + + @Test + public void testDescribeSchema() { + Schema schema = + new Schema( + required(1, "data", Types.ListType.ofRequired(2, 
Types.StringType.get())), + optional( + 3, + "pairs", + Types.MapType.ofOptional(4, 5, Types.StringType.get(), Types.LongType.get())), + required(6, "time", Types.TimestampType.withoutZone())); + + assertThat(Spark3Util.describe(schema)) + .as("Schema description isn't correct.") + .isEqualTo( + "struct<data: list<string> not null,pairs: map<string, bigint>,time: timestamp not null>"); + } + + @Test + public void testLoadIcebergTable() throws Exception { + spark.conf().set("spark.sql.catalog.hive", SparkCatalog.class.getName()); + spark.conf().set("spark.sql.catalog.hive.type", "hive"); + spark.conf().set("spark.sql.catalog.hive.default-namespace", "default"); + + String tableFullName = "hive.default.tbl"; + sql("CREATE TABLE %s (c1 bigint, c2 string, c3 string) USING iceberg", tableFullName); + + Table table = Spark3Util.loadIcebergTable(spark, tableFullName); + assertThat(table.name()).isEqualTo(tableFullName); + } + + @Test + public void testLoadIcebergCatalog() throws Exception { + spark.conf().set("spark.sql.catalog.test_cat", SparkCatalog.class.getName()); + spark.conf().set("spark.sql.catalog.test_cat.type", "hive"); + Catalog catalog = Spark3Util.loadIcebergCatalog(spark, "test_cat"); + assertThat(catalog) + .as("Should retrieve underlying catalog class") + .isInstanceOf(CachingCatalog.class); + } + + @Test + public void testDescribeExpression() { + Expression refExpression = equal("id", 1); + assertThat(Spark3Util.describe(refExpression)).isEqualTo("id = 1"); + + Expression yearExpression = greaterThan(year("ts"), 10); + assertThat(Spark3Util.describe(yearExpression)).isEqualTo("year(ts) > 10"); + + Expression monthExpression = greaterThanOrEqual(month("ts"), 10); + assertThat(Spark3Util.describe(monthExpression)).isEqualTo("month(ts) >= 10"); + + Expression dayExpression = lessThan(day("ts"), 10); + assertThat(Spark3Util.describe(dayExpression)).isEqualTo("day(ts) < 10"); + + Expression hourExpression = lessThanOrEqual(hour("ts"), 10); + assertThat(Spark3Util.describe(hourExpression)).isEqualTo("hour(ts) <= 10"); + + Expression bucketExpression = in(bucket("id", 5), 3); + assertThat(Spark3Util.describe(bucketExpression)).isEqualTo("bucket[5](id) IN (3)"); + + Expression truncateExpression = notIn(truncate("name", 3), "abc"); + assertThat(Spark3Util.describe(truncateExpression)) + .isEqualTo("truncate[3](name) NOT IN ('abc')"); + + Expression andExpression = and(refExpression, yearExpression); + assertThat(Spark3Util.describe(andExpression)).isEqualTo("(id = 1 AND year(ts) > 10)"); + } + + private SortOrder buildSortOrder(String transform, Schema schema, int sourceId) { + String jsonString = + "{\n" + + " \"order-id\" : 10,\n" + + " \"fields\" : [ {\n" + + " \"transform\" : \"" + + transform + + "\",\n" + + " \"source-id\" : " + + sourceId + + ",\n" + + " \"direction\" : \"desc\",\n" + + " \"null-order\" : \"nulls-first\"\n" + + " } ]\n" + + "}"; + + return SortOrderParser.fromJson(schema, jsonString); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkCachedTableCatalog.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkCachedTableCatalog.java new file mode 100644 index 000000000000..eaf230865957 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkCachedTableCatalog.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkCachedTableCatalog extends TestBaseWithCatalog { + + private static final SparkTableCache TABLE_CACHE = SparkTableCache.get(); + + @BeforeAll + public static void setupCachedTableCatalog() { + spark.conf().set("spark.sql.catalog.testcache", SparkCachedTableCatalog.class.getName()); + } + + @AfterAll + public static void unsetCachedTableCatalog() { + spark.conf().unset("spark.sql.catalog.testcache"); + } + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + protected static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties() + }, + }; + } + + @TestTemplate + public void testTimeTravel() { + sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + + table.refresh(); + Snapshot firstSnapshot = table.currentSnapshot(); + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (2, 'hr')", tableName); + + table.refresh(); + Snapshot secondSnapshot = table.currentSnapshot(); + waitUntilAfter(secondSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (3, 'hr')", tableName); + + table.refresh(); + + try { + TABLE_CACHE.add("key", table); + + assertEquals( + "Should have expected rows in 3rd snapshot", + ImmutableList.of(row(1, "hr"), row(2, "hr"), row(3, "hr")), + sql("SELECT * FROM testcache.key ORDER BY id")); + + assertEquals( + "Should have expected rows in 2nd snapshot", + ImmutableList.of(row(1, "hr"), row(2, "hr")), + sql( + "SELECT * FROM testcache.`key#at_timestamp_%s` ORDER BY id", + secondSnapshot.timestampMillis())); + + assertEquals( + "Should have expected rows in 1st snapshot", + ImmutableList.of(row(1, "hr")), + sql( + "SELECT * FROM testcache.`key#snapshot_id_%d` ORDER BY id", + firstSnapshot.snapshotId())); + + } finally { + TABLE_CACHE.remove("key"); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkCatalogOperations.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkCatalogOperations.java new file mode 100644 index 000000000000..186042283bdb --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkCatalogOperations.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.concurrent.ThreadLocalRandom; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Schema; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkCatalogOperations extends CatalogTestBase { + private static final boolean USE_NULLABLE_QUERY_SCHEMA = + ThreadLocalRandom.current().nextBoolean(); + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + protected static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "use-nullable-query-schema", Boolean.toString(USE_NULLABLE_QUERY_SCHEMA)) + }, + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + ImmutableMap.of( + "type", + "hadoop", + "cache-enabled", + "false", + "use-nullable-query-schema", + Boolean.toString(USE_NULLABLE_QUERY_SCHEMA)) + }, + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + ImmutableMap.of( + "type", + "hive", + "default-namespace", + "default", + "parquet-enabled", + "true", + "cache-enabled", + "false", // Spark will delete tables using v1, leaving the cache out of sync + "use-nullable-query-schema", + Boolean.toString(USE_NULLABLE_QUERY_SCHEMA)), + } + }; + } + + @BeforeEach + public void createTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testAlterTable() throws NoSuchTableException { + BaseCatalog catalog = (BaseCatalog) spark.sessionState().catalogManager().catalog(catalogName); + Identifier identifier = Identifier.of(tableIdent.namespace().levels(), tableIdent.name()); + + String fieldName = "location"; + String propsKey = "note"; + String propsValue = "jazz"; + Table table = + catalog.alterTable( + identifier, + TableChange.addColumn(new String[] {fieldName}, DataTypes.StringType, true), + 
TableChange.setProperty(propsKey, propsValue)); + + assertThat(table).as("Should return updated table").isNotNull(); + + StructField expectedField = DataTypes.createStructField(fieldName, DataTypes.StringType, true); + assertThat(table.schema().fields()[2]) + .as("Adding a column to a table should return the updated table with the new column") + .isEqualTo(expectedField); + + assertThat(table.properties()) + .as( + "Adding a property to a table should return the updated table with the new property with the new correct value") + .containsEntry(propsKey, propsValue); + } + + @TestTemplate + public void testInvalidateTable() { + // load table to CachingCatalog + sql("SELECT count(1) FROM %s", tableName); + + // recreate table from another catalog or program + Catalog anotherCatalog = validationCatalog; + Schema schema = anotherCatalog.loadTable(tableIdent).schema(); + anotherCatalog.dropTable(tableIdent); + anotherCatalog.createTable(tableIdent, schema); + + // invalidate and reload table + sql("REFRESH TABLE %s", tableName); + sql("SELECT count(1) FROM %s", tableName); + } + + @TestTemplate + public void testCTASUseNullableQuerySchema() { + sql("INSERT INTO %s VALUES(1, 'abc'), (2, null)", tableName); + + String ctasTableName = tableName("ctas_table"); + + sql("CREATE TABLE %s USING iceberg AS SELECT * FROM %s", ctasTableName, tableName); + + org.apache.iceberg.Table ctasTable = + validationCatalog.loadTable(TableIdentifier.parse("default.ctas_table")); + + Schema expectedSchema = + new Schema( + USE_NULLABLE_QUERY_SCHEMA + ? Types.NestedField.optional(1, "id", Types.LongType.get()) + : Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get())); + + assertThat(ctasTable.schema().asStruct()) + .as("Should have expected schema") + .isEqualTo(expectedSchema.asStruct()); + + sql("DROP TABLE IF EXISTS %s", ctasTableName); + } + + @TestTemplate + public void testRTASUseNullableQuerySchema() { + sql("INSERT INTO %s VALUES(1, 'abc'), (2, null)", tableName); + + String rtasTableName = tableName("rtas_table"); + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", rtasTableName); + + sql("REPLACE TABLE %s USING iceberg AS SELECT * FROM %s", rtasTableName, tableName); + + org.apache.iceberg.Table rtasTable = + validationCatalog.loadTable(TableIdentifier.parse("default.rtas_table")); + + Schema expectedSchema = + new Schema( + USE_NULLABLE_QUERY_SCHEMA + ? Types.NestedField.optional(1, "id", Types.LongType.get()) + : Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get())); + + assertThat(rtasTable.schema().asStruct()) + .as("Should have expected schema") + .isEqualTo(expectedSchema.asStruct()); + + assertEquals( + "Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", tableName), + sql("SELECT * FROM %s ORDER BY id", rtasTableName)); + + sql("DROP TABLE IF EXISTS %s", rtasTableName); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkCompressionUtil.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkCompressionUtil.java new file mode 100644 index 000000000000..aa329efbbad5 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkCompressionUtil.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.FileFormat.AVRO; +import static org.apache.iceberg.FileFormat.METADATA; +import static org.apache.iceberg.FileFormat.ORC; +import static org.apache.iceberg.FileFormat.PARQUET; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.apache.iceberg.FileFormat; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.internal.config.package$; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class TestSparkCompressionUtil { + + private SparkSession spark; + private SparkConf sparkConf; + + @BeforeEach + public void initSpark() { + this.spark = mock(SparkSession.class); + this.sparkConf = mock(SparkConf.class); + + SparkContext sparkContext = mock(SparkContext.class); + + when(spark.sparkContext()).thenReturn(sparkContext); + when(sparkContext.conf()).thenReturn(sparkConf); + } + + @Test + public void testParquetCompressionRatios() { + configureShuffle("lz4", true); + + double ratio1 = shuffleCompressionRatio(PARQUET, "zstd"); + assertThat(ratio1).isEqualTo(3.0); + + double ratio2 = shuffleCompressionRatio(PARQUET, "gzip"); + assertThat(ratio2).isEqualTo(3.0); + + double ratio3 = shuffleCompressionRatio(PARQUET, "snappy"); + assertThat(ratio3).isEqualTo(2.0); + } + + @Test + public void testOrcCompressionRatios() { + configureShuffle("lz4", true); + + double ratio1 = shuffleCompressionRatio(ORC, "zlib"); + assertThat(ratio1).isEqualTo(3.0); + + double ratio2 = shuffleCompressionRatio(ORC, "lz4"); + assertThat(ratio2).isEqualTo(2.0); + } + + @Test + public void testAvroCompressionRatios() { + configureShuffle("lz4", true); + + double ratio1 = shuffleCompressionRatio(AVRO, "gzip"); + assertThat(ratio1).isEqualTo(1.5); + + double ratio2 = shuffleCompressionRatio(AVRO, "zstd"); + assertThat(ratio2).isEqualTo(1.5); + } + + @Test + public void testCodecNameNormalization() { + configureShuffle("zStD", true); + double ratio = shuffleCompressionRatio(PARQUET, "ZstD"); + assertThat(ratio).isEqualTo(2.0); + } + + @Test + public void testUnknownCodecNames() { + configureShuffle("SOME_SPARK_CODEC", true); + + double ratio1 = shuffleCompressionRatio(PARQUET, "SOME_PARQUET_CODEC"); + assertThat(ratio1).isEqualTo(2.0); + + double ratio2 = shuffleCompressionRatio(ORC, "SOME_ORC_CODEC"); + assertThat(ratio2).isEqualTo(2.0); + + double ratio3 = shuffleCompressionRatio(AVRO, "SOME_AVRO_CODEC"); + assertThat(ratio3).isEqualTo(1.0); + } + + @Test + public void 
testOtherFileFormats() { + configureShuffle("lz4", true); + double ratio = shuffleCompressionRatio(METADATA, "zstd"); + assertThat(ratio).isEqualTo(1.0); + } + + @Test + public void testNullFileCodec() { + configureShuffle("lz4", true); + + double ratio1 = shuffleCompressionRatio(PARQUET, null); + assertThat(ratio1).isEqualTo(2.0); + + double ratio2 = shuffleCompressionRatio(ORC, null); + assertThat(ratio2).isEqualTo(2.0); + + double ratio3 = shuffleCompressionRatio(AVRO, null); + assertThat(ratio3).isEqualTo(1.0); + } + + @Test + public void testUncompressedShuffles() { + configureShuffle("zstd", false); + + double ratio1 = shuffleCompressionRatio(PARQUET, "zstd"); + assertThat(ratio1).isEqualTo(4.0); + + double ratio2 = shuffleCompressionRatio(ORC, "zlib"); + assertThat(ratio2).isEqualTo(4.0); + + double ratio3 = shuffleCompressionRatio(AVRO, "gzip"); + assertThat(ratio3).isEqualTo(2.0); + } + + @Test + public void testSparkDefaults() { + assertThat(package$.MODULE$.SHUFFLE_COMPRESS().defaultValueString()).isEqualTo("true"); + assertThat(package$.MODULE$.IO_COMPRESSION_CODEC().defaultValueString()).isEqualTo("lz4"); + } + + private void configureShuffle(String codec, boolean compress) { + when(sparkConf.getBoolean(eq("spark.shuffle.compress"), anyBoolean())).thenReturn(compress); + when(sparkConf.get(eq("spark.io.compression.codec"), anyString())).thenReturn(codec); + } + + private double shuffleCompressionRatio(FileFormat fileFormat, String codec) { + return SparkCompressionUtil.shuffleCompressionRatio(spark, fileFormat, codec); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkDistributionAndOrderingUtil.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkDistributionAndOrderingUtil.java new file mode 100644 index 000000000000..39ef72c6bb1d --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkDistributionAndOrderingUtil.java @@ -0,0 +1,3026 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.TableProperties.DELETE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.MERGE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED; +import static org.apache.iceberg.TableProperties.UPDATE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_NONE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.SortDirection; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkDistributionAndOrderingUtil extends TestBaseWithCatalog { + + private static final Distribution UNSPECIFIED_DISTRIBUTION = Distributions.unspecified(); + private static final Distribution FILE_CLUSTERED_DISTRIBUTION = + Distributions.clustered( + new Expression[] {Expressions.column(MetadataColumns.FILE_PATH.name())}); + private static final Distribution SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION = + Distributions.clustered( + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME) + }); + private static final Distribution SPEC_ID_PARTITION_FILE_CLUSTERED_DISTRIBUTION = + Distributions.clustered( + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.column(MetadataColumns.FILE_PATH.name()) + }); + + private static final SortOrder[] EMPTY_ORDERING = new SortOrder[] {}; + private static final SortOrder[] FILE_POSITION_ORDERING = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + private static final SortOrder[] SPEC_ID_PARTITION_ORDERING = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING) + }; + private static final SortOrder[] SPEC_ID_PARTITION_FILE_ORDERING = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + 
Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING) + }; + private static final SortOrder[] SPEC_ID_PARTITION_FILE_POSITION_ORDERING = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + @AfterEach + public void dropTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + // ============================================================= + // Distribution and ordering for write operations such as APPEND + // ============================================================= + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // write mode is NOT SET -> unspecified distribution + empty ordering + // write mode is NONE -> unspecified distribution + empty ordering + // write mode is HASH -> unspecified distribution + empty ordering + // write mode is RANGE -> unspecified distribution + empty ordering + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // write mode is NOT SET -> ORDER BY id, data + // write mode is NONE -> unspecified distribution + LOCALLY ORDER BY id, data + // write mode is HASH -> unspecified distribution + LOCALLY ORDER BY id, data + // write mode is RANGE -> ORDER BY id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // write mode is NOT SET -> CLUSTER BY date, days(ts) + LOCALLY ORDER BY date, days(ts) + // write mode is NOT SET (fanout) -> CLUSTER BY date, days(ts) + empty ordering + // write mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, days(ts) + // write mode is NONE (fanout) -> unspecified distribution + empty ordering + // write mode is HASH -> CLUSTER BY date, days(ts) + LOCALLY ORDER BY date, days(ts) + // write mode is HASH (fanout) -> CLUSTER BY date, days(ts) + empty ordering + // write mode is RANGE -> ORDER BY date, days(ts) + // write mode is RANGE (fanout) -> RANGE DISTRIBUTE BY date, days(ts) + empty ordering + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // write mode is NOT SET -> ORDER BY date, id + // write mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, id + // write mode is HASH -> CLUSTER BY date + LOCALLY ORDER BY date, id + // write mode is RANGE -> ORDER BY date, id + + @TestTemplate + public void testDefaultWriteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkWriteDistributionAndOrdering(table, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashWriteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + checkWriteDistributionAndOrdering(table, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate 
+ public void testRangeWriteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + checkWriteDistributionAndOrdering(table, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testDefaultWriteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testHashWriteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkWriteDistributionAndOrdering(table, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testRangeWriteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testDefaultWritePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.days("ts")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkWriteDistributionAndOrdering(table, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashWritePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED 
BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.days("ts")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkWriteDistributionAndOrdering(table, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testRangeWritePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + disableFanoutWriters(table); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkWriteDistributionAndOrdering(table, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testDefaultWritePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().desc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testHashWritePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").commit(); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.bucket(8, "data")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testRangeWritePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id 
BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").commit(); + + disableFanoutWriters(table); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + // ============================================================= + // Distribution and ordering for copy-on-write DELETE operations + // ============================================================= + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY _file + empty ordering + // delete mode is NONE -> unspecified distribution + empty ordering + // delete mode is HASH -> CLUSTER BY _file + empty ordering + // delete mode is RANGE -> RANGE DISTRIBUTE BY _file, _pos + empty ordering + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY id, data + // delete mode is NONE -> unspecified distribution + LOCALLY ORDER BY id, data + // delete mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY id, data + // delete mode is RANGE -> ORDER BY id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY date, days(ts) + LOCALLY ORDER BY date, days(ts) + // delete mode is NOT SET (fanout) -> CLUSTER BY date, days(ts) + empty ordering + // delete mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, days(ts) + // delete mode is NONE (fanout) -> unspecified distribution + empty ordering + // delete mode is HASH -> CLUSTER BY date, days(ts) + LOCALLY ORDER BY date, days(ts) + // delete mode is HASH (fanout) -> CLUSTER BY date, days(ts) + empty ordering + // delete mode is RANGE -> ORDER BY date, days(ts) + // delete mode is RANGE (fanout) -> RANGE DISTRIBUTE BY date, days(ts) + empty ordering + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY date + LOCALLY ORDER BY date, id + // delete mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, id + // delete mode is HASH -> CLUSTER BY date + LOCALLY ORDER BY date, id + // delete mode is RANGE -> ORDER BY date, id + + @TestTemplate + public void testDefaultCopyOnWriteDeleteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, FILE_CLUSTERED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testNoneCopyOnWriteDeleteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, 
WRITE_DISTRIBUTION_MODE_NONE).commit(); + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashCopyOnWriteDeleteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, FILE_CLUSTERED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testRangeCopyOnWriteDeleteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + Distribution expectedDistribution = Distributions.ordered(FILE_POSITION_ORDERING); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testDefaultCopyOnWriteDeleteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testNoneCopyOnWriteDeleteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testHashCopyOnWriteDeleteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testRangeCopyOnWriteDeleteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + 
Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testDefaultCopyOnWriteDeletePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.days("ts")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testNoneCopyOnWriteDeletePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + disableFanoutWriters(table); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + + enableFanoutWriters(table); + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashCopyOnWriteDeletePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.days("ts")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testRangeCopyOnWriteDeletePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + 
table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + disableFanoutWriters(table); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testDefaultCopyOnWriteDeletePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().desc("id").commit(); + + Expression[] expectedClustering = new Expression[] {Expressions.identity("date")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testNoneCopyOnWriteDeletePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().desc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testHashCopyOnWriteDeletePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").commit(); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.bucket(8, "data")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testRangeCopyOnWriteDeletePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + 
table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, expectedOrdering); + } + + // ============================================================= + // Distribution and ordering for copy-on-write UPDATE operations + // ============================================================= + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _file + empty ordering + // update mode is NONE -> unspecified distribution + empty ordering + // update mode is HASH -> CLUSTER BY _file + empty ordering + // update mode is RANGE -> RANGE DISTRIBUTE BY _file, _pos + empty ordering + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY id, data + // update mode is NONE -> unspecified distribution + LOCALLY ORDER BY id, data + // update mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY id, data + // update mode is RANGE -> ORDER BY id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY date, days(ts) + LOCALLY ORDER BY date, days(ts) + // update mode is NOT SET (fanout) -> CLUSTER BY _file + empty ordering + // update mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, days(ts) + // update mode is NONE (fanout) -> unspecified distribution + empty ordering + // update mode is HASH -> CLUSTER BY date, days(ts) + LOCALLY ORDER BY date, days(ts) + // update mode is HASH (fanout) -> CLUSTER BY date, days(ts) + empty ordering + // update mode is RANGE -> ORDER BY date, days(ts) + // update mode is RANGE (fanout) -> RANGE DISTRIBUTED BY date, days(ts) + empty ordering + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY date + LOCALLY ORDER BY date, id + // update mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, id + // update mode is HASH -> CLUSTER BY date + LOCALLY ORDER BY date, id + // update mode is RANGE -> ORDER BY date, id + + @TestTemplate + public void testDefaultCopyOnWriteUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testNoneCopyOnWriteUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashCopyOnWriteUpdateUnpartitionedUnsortedTable() { 
+ sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testRangeCopyOnWriteUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + Distribution expectedDistribution = Distributions.ordered(FILE_POSITION_ORDERING); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testDefaultCopyOnWriteUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testNoneCopyOnWriteUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testHashCopyOnWriteUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testRangeCopyOnWriteUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + 
checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testDefaultCopyOnWriteUpdatePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.days("ts")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testNoneCopyOnWriteUpdatePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + disableFanoutWriters(table); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + + enableFanoutWriters(table); + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashCopyOnWriteUpdatePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.days("ts")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testRangeCopyOnWriteUpdatePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + disableFanoutWriters(table); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), 
SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testDefaultCopyOnWriteUpdatePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().desc("id").commit(); + + Expression[] expectedClustering = new Expression[] {Expressions.identity("date")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testNoneCopyOnWriteUpdatePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().desc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testHashCopyOnWriteUpdatePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").commit(); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.bucket(8, "data")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testRangeCopyOnWriteUpdatePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), 
SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, expectedOrdering); + } + + // ============================================================= + // Distribution and ordering for copy-on-write MERGE operations + // ============================================================= + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> use write distribution and ordering + // merge mode is NONE -> unspecified distribution + empty ordering + // merge mode is HASH -> unspecified distribution + empty ordering + // merge mode is RANGE -> unspecified distribution + empty ordering + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> use write distribution and ordering + // merge mode is NONE -> unspecified distribution + LOCALLY ORDER BY id, data + // merge mode is HASH -> unspecified distribution + LOCALLY ORDER BY id, data + // merge mode is RANGE -> ORDER BY id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> CLUSTER BY date, days(ts) + LOCALLY ORDER BY date, days(ts) + // merge mode is NOT SET (fanout) -> CLUSTER BY date, days(ts) + empty ordering + // merge mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, days(ts) + // merge mode is NONE (fanout) -> unspecified distribution + empty ordering + // merge mode is HASH -> CLUSTER BY date, days(ts) + LOCALLY ORDER BY date, days(ts) + // merge mode is HASH (fanout) -> CLUSTER BY date, days(ts) + empty ordering + // merge mode is RANGE -> ORDER BY date, days(ts) + // merge mode is RANGE (fanout) -> RANGE DISTRIBUTE BY date, days(ts) + empty ordering + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> CLUSTER BY date + LOCALLY ORDER BY date, id + // merge mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, id + // merge mode is HASH -> CLUSTER BY date + LOCALLY ORDER BY date, id + // merge mode is RANGE -> ORDERED BY date, id + + @TestTemplate + public void testDefaultCopyOnWriteMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testNoneCopyOnWriteMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashCopyOnWriteMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + 
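+    // With no partition columns and no table sort order there is nothing to cluster on, so HASH merge mode resolves to an unspecified distribution and an empty ordering.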
checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testRangeCopyOnWriteMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testDefaultCopyOnWriteMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testNoneCopyOnWriteMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testHashCopyOnWriteMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testRangeCopyOnWriteMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testDefaultCopyOnWriteMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED 
BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.days("ts")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testNoneCopyOnWriteMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + disableFanoutWriters(table); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + + enableFanoutWriters(table); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashCopyOnWriteMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.days("ts")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testRangeCopyOnWriteMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + disableFanoutWriters(table); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + 
checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testDefaultCopyOnWriteMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().desc("id").commit(); + + Expression[] expectedClustering = new Expression[] {Expressions.identity("date")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testNoneCopyOnWriteMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().desc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testHashCopyOnWriteMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").commit(); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.bucket(8, "data")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testRangeCopyOnWriteMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + // 
=================================================================================== + // Distribution and ordering for merge-on-read DELETE operations with position deletes + // =================================================================================== + // + // UNPARTITIONED (ORDERED & UNORDERED) + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos + // delete mode is NOT SET (fanout) -> CLUSTER BY _spec_id, _partition, _file + empty ordering + // delete mode is NONE -> unspecified distribution + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos + // delete mode is NONE (fanout) -> unspecified distribution + empty ordering + // delete mode is HASH -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos + // delete mode is HASH (fanout) -> CLUSTER BY _spec_id, _partition, _file + empty ordering + // delete mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos + // delete mode is RANGE (fanout) -> RANGE DISTRIBUTE BY _spec_id, _partition, _file + + // empty ordering + // + // PARTITIONED BY date, days(ts) (ORDERED & UNORDERED) + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY _spec_id, _partition + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos + // delete mode is NOT SET (fanout) -> CLUSTER BY _spec_id, _partition + empty ordering + // delete mode is NONE -> unspecified distribution + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos + // delete mode is NONE (fanout) -> unspecified distribution + empty ordering + // delete mode is HASH -> CLUSTER BY _spec_id, _partition + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos + // delete mode is HASH (fanout) -> CLUSTER BY _spec_id, _partition + empty ordering + // delete mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos + // delete mode is RANGE (fanout) -> RANGE DISTRIBUTE BY _spec_id, _partition + empty ordering + + @TestTemplate + public void testDefaultPositionDeltaDeleteUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + disableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, + DELETE, + SPEC_ID_PARTITION_FILE_CLUSTERED_DISTRIBUTION, + SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, SPEC_ID_PARTITION_FILE_CLUSTERED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testNonePositionDeltaDeleteUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + disableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashPositionDeltaDeleteUnpartitionedTable() { + 
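+    // HASH delete mode clusters position deletes by _spec_id, _partition, _file; the local ordering by _pos is only required while fanout writers are disabled.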
sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + disableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, + DELETE, + SPEC_ID_PARTITION_FILE_CLUSTERED_DISTRIBUTION, + SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, SPEC_ID_PARTITION_FILE_CLUSTERED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testRangePositionDeltaDeleteUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + disableFanoutWriters(table); + + Distribution expectedDistribution = Distributions.ordered(SPEC_ID_PARTITION_FILE_ORDERING); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering(table, DELETE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testDefaultPositionDeltaDeletePartitionedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + disableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, + DELETE, + SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, + SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testNonePositionDeltaDeletePartitionedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + disableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashPositionDeltaDeletePartitionedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + disableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, + DELETE, + SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, + SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testRangePositionDeltaDeletePartitionedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts 
TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + disableFanoutWriters(table); + + Distribution expectedDistribution = Distributions.ordered(SPEC_ID_PARTITION_ORDERING); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering(table, DELETE, expectedDistribution, EMPTY_ORDERING); + } + + // =================================================================================== + // Distribution and ordering for merge-on-read UPDATE operations with position deletes + // =================================================================================== + // + // IMPORTANT: updates are represented as delete and insert + // IMPORTANT: metadata columns like _spec_id and _partition are NULL for new insert rows + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos + // update mode is NOT SET (fanout) -> CLUSTER BY _spec_id, _partition, _file + empty ordering + // update mode is NONE -> unspecified distribution + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos + // update mode is NONE (fanout) -> unspecified distribution + empty ordering + // update mode is HASH -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos + // update mode is HASH (fanout) -> CLUSTER BY _spec_id, _partition, _file + empty ordering + // update mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos + // update mode is RANGE (fanout) -> RANGE DISTRIBUTE BY _spec_id, _partition, _file + empty + // ordering + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, id, data + // update mode is NONE -> unspecified distribution + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, id, data + // update mode is HASH -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, id, data + // update mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file, id, data + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _spec_id, _partition, date, days(ts) + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, days(ts) + // update mode is NOT SET (fanout) -> CLUSTER BY _spec_id, _partition, date, days(ts) + + // empty ordering + // update mode is NONE -> unspecified distribution + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, days(ts) + // update mode is NONE (fanout) -> unspecified distribution + empty ordering + // update mode is HASH -> CLUSTER BY _spec_id, _partition, date, days(ts) + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, days(ts) + // update mode is HASH (fanout) -> CLUSTER BY 
_spec_id, _partition, date, days(ts) + + // empty ordering + // update mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, date, days(ts) + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, days(ts) + // update mode is RANGE (fanout) -> RANGE DISTRIBUTE BY _spec_id, _partition, date, days(ts) + + // empty ordering + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _spec_id, _partition, date + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, id + // update mode is NONE -> unspecified distribution + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, id + // update mode is HASH -> CLUSTER BY _spec_id, _partition, date + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, id + // update mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, date, id + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, id + + @TestTemplate + public void testDefaultPositionDeltaUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + disableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, + UPDATE, + SPEC_ID_PARTITION_FILE_CLUSTERED_DISTRIBUTION, + SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, SPEC_ID_PARTITION_FILE_CLUSTERED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testNonePositionDeltaUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + disableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashPositionDeltaUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + disableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, + UPDATE, + SPEC_ID_PARTITION_FILE_CLUSTERED_DISTRIBUTION, + SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, SPEC_ID_PARTITION_FILE_CLUSTERED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testRangePositionDeltaUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + disableFanoutWriters(table); + + Distribution expectedDistribution = Distributions.ordered(SPEC_ID_PARTITION_FILE_ORDERING); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + 
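+    // Enabling fanout writers drops the local ordering by _pos but keeps the range distribution by _spec_id, _partition, _file.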
checkPositionDeltaDistributionAndOrdering(table, UPDATE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testDefaultPositionDeltaUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, SPEC_ID_PARTITION_FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testNonePositionDeltaUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testHashPositionDeltaUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, SPEC_ID_PARTITION_FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testRangePositionDeltaUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + 
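+    // RANGE update mode on a sorted table range-distributes by _spec_id, _partition, _file plus the table sort order (id, data), and locally orders by the same columns with _pos added.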
Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + SortOrder[] expectedDistributionOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testDefaultPositionDeltaUpdatePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.bucket(8, "data") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering(table, UPDATE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testNonePositionDeltaUpdatePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + disableFanoutWriters(table); + + SortOrder[] 
expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashPositionDeltaUpdatePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.bucket(8, "data") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering(table, UPDATE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testRangePositionDeltaUpdatePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + disableFanoutWriters(table); + + SortOrder[] expectedDistributionOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + 
Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering(table, UPDATE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testDefaultPositionDeltaUpdatePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").commit(); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.bucket(8, "data") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testNonePositionDeltaUpdatePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").commit(); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public 
void testHashPositionDeltaUpdatePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").commit(); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.bucket(8, "data") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testRangePositionDeltaUpdatePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").commit(); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + SortOrder[] expectedDistributionOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, expectedDistribution, expectedOrdering); + } + + // ================================================================================== + // Distribution and ordering for merge-on-read MERGE operations 
with position deletes + // ================================================================================== + // + // IMPORTANT: updates are represented as delete and insert + // IMPORTANT: metadata columns like _spec_id and _partition are NULL for new rows + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // merge mode is NOT SET (fanout) -> CLUSTER BY _spec_id, _partition, _file + empty ordering + // merge mode is NONE -> unspecified distribution + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // merge mode is NONE (fanout) -> unspecified distribution + empty ordering + // merge mode is HASH -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // merge mode is HASH (fanout) -> CLUSTER BY _spec_id, _partition, _file + + // empty ordering + // merge mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // merge mode is RANGE (fanout) -> RANGE DISTRIBUTE BY _spec_id, _partition, _file + + // empty ordering + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, id, data + // merge mode is NONE -> unspecified distribution + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, id, data + // merge mode is HASH -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, id, data + // merge mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file, id, data + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> CLUSTER BY _spec_id, _partition, date, days(ts) + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, date, days(ts) + // merge mode is NOT SET (fanout) -> CLUSTER BY _spec_id, _partition, date, days(ts) + + // empty ordering + // merge mode is NONE -> unspecified distribution + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, days(ts) + // merge mode is NONE (fanout) -> unspecified distribution + empty ordering + // merge mode is HASH -> CLUSTER BY _spec_id, _partition, date, days(ts) + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, date, days(ts) + // merge mode is HASH (fanout) -> CLUSTER BY _spec_id, _partition, date, days(ts) + + // empty ordering + // merge mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, date, days(ts) + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, date, days(ts) + // merge mode is RANGE (fanout) -> RANGE DISTRIBUTE BY _spec_id, _partition, date, days(ts) + + // empty ordering + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> CLUSTER BY _spec_id, _partition, date + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, date, id + // merge mode is NONE -> unspecified distribution + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, id + // merge mode is HASH -> CLUSTER BY _spec_id, _partition, date + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, date, id + 
// merge mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, date, id + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, id + + @TestTemplate + public void testDefaultPositionDeltaMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.column(MetadataColumns.FILE_PATH.name()) + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testNonePositionDeltaMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + disableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashPositionDeltaMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.column(MetadataColumns.FILE_PATH.name()) + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testRangePositionDeltaMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + disableFanoutWriters(table); + + SortOrder[] expectedDistributionOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + + enableFanoutWriters(table); + + 
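The expectation table above translates one-to-one into the assertions in these tests. As a compact sketch of the pattern (reusing names that come from this test class and its static imports: spark, MERGE, MERGE_DISTRIBUTION_MODE, SPEC_ID_PARTITION_FILE_POSITION_ORDERING), the hash case for an unpartitioned, unsorted table reduces to:

    // position deltas are clustered by the metadata columns that identify the target
    // data file, so all deletes for a given file end up in the same task
    table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit();

    SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of());
    SparkWriteRequirements requirements = writeConf.positionDeltaRequirements(MERGE);

    // CLUSTER BY _spec_id, _partition, _file
    Distribution expected =
        Distributions.clustered(
            new Expression[] {
              Expressions.column(MetadataColumns.SPEC_ID.name()),
              Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME),
              Expressions.column(MetadataColumns.FILE_PATH.name())
            });
    assertThat(requirements.distribution()).isEqualTo(expected);

    // LOCALLY ORDER BY _spec_id, _partition, _file, _pos
    assertThat(requirements.ordering()).isEqualTo(SPEC_ID_PARTITION_FILE_POSITION_ORDERING);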
checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testDefaultPositionDeltaMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.column(MetadataColumns.FILE_PATH.name()) + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testNonePositionDeltaMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testHashPositionDeltaMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.column(MetadataColumns.FILE_PATH.name()) + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + 
Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testRangePositionDeltaMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedDistributionOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testDefaultPositionDeltaMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.days("ts") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + + 
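A recurring pattern in these tests: whenever fanout writers are enabled, the expected local ordering collapses to EMPTY_ORDERING while the required distribution stays the same. A fanout writer keeps an open output file per encountered partition/file key, so it does not depend on incoming rows being sorted. The toggle itself is just a table property flip, as in the helper methods at the end of this file:

    // fanout disabled: rows must arrive sorted by _spec_id, _partition, _file, _pos (plus any table sort)
    table.updateProperties().set(SPARK_WRITE_PARTITIONED_FANOUT_ENABLED, "false").commit();

    // fanout enabled: same clustering requirement, but no required local sort (EMPTY_ORDERING)
    table.updateProperties().set(SPARK_WRITE_PARTITIONED_FANOUT_ENABLED, "true").commit();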
enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testNonePositionDeltaMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + disableFanoutWriters(table); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @TestTemplate + public void testHashPositionDeltaMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + disableFanoutWriters(table); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.days("ts") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testRangePositionDeltaMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + disableFanoutWriters(table); + + SortOrder[] expectedDistributionOrdering = + new SortOrder[] { + Expressions.sort( + 
Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + + enableFanoutWriters(table); + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, EMPTY_ORDERING); + } + + @TestTemplate + public void testNonePositionDeltaMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().desc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @TestTemplate + public void testDefaultPositionDeltaMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").commit(); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.bucket(8, "data") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + 
Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testHashPositionDeltaMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").commit(); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.bucket(8, "data") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @TestTemplate + public void testRangePositionDeltaMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").commit(); + + SortOrder[] expectedDistributionOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + 
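In the range cases, the required distribution is expressed as an array of sort orders rather than clustering expressions: Distributions.ordered(...) is how a DSv2 writer asks Spark to range-partition incoming rows by those keys, with the table's own sort columns appended after the metadata and partition columns. Reduced to its shape (same expressions as in the test above):

    // RANGE DISTRIBUTE BY _spec_id, _partition, date, id
    Distribution rangeDistribution =
        Distributions.ordered(
            new SortOrder[] {
              Expressions.sort(
                  Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING),
              Expressions.sort(
                  Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING),
              Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING),
              Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING)
            });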
checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + private void checkWriteDistributionAndOrdering( + Table table, Distribution expectedDistribution, SortOrder[] expectedOrdering) { + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + SparkWriteRequirements requirements = writeConf.writeRequirements(); + + Distribution distribution = requirements.distribution(); + assertThat(distribution).as("Distribution must match").isEqualTo(expectedDistribution); + + SortOrder[] ordering = requirements.ordering(); + assertThat(ordering).as("Ordering must match").isEqualTo(expectedOrdering); + } + + private void checkCopyOnWriteDistributionAndOrdering( + Table table, + Command command, + Distribution expectedDistribution, + SortOrder[] expectedOrdering) { + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + SparkWriteRequirements requirements = writeConf.copyOnWriteRequirements(command); + + Distribution distribution = requirements.distribution(); + assertThat(distribution).as("Distribution must match").isEqualTo(expectedDistribution); + + SortOrder[] ordering = requirements.ordering(); + assertThat(ordering).as("Ordering must match").isEqualTo(expectedOrdering); + } + + private void checkPositionDeltaDistributionAndOrdering( + Table table, + Command command, + Distribution expectedDistribution, + SortOrder[] expectedOrdering) { + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + SparkWriteRequirements requirements = writeConf.positionDeltaRequirements(command); + + Distribution distribution = requirements.distribution(); + assertThat(distribution).as("Distribution must match").isEqualTo(expectedDistribution); + + SortOrder[] ordering = requirements.ordering(); + assertThat(ordering).as("Ordering must match").isEqualTo(expectedOrdering); + } + + private void disableFanoutWriters(Table table) { + table.updateProperties().set(SPARK_WRITE_PARTITIONED_FANOUT_ENABLED, "false").commit(); + } + + private void enableFanoutWriters(Table table) { + table.updateProperties().set(SPARK_WRITE_PARTITIONED_FANOUT_ENABLED, "true").commit(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkExecutorCache.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkExecutorCache.java new file mode 100644 index 000000000000..d9d7b78799ba --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkExecutorCache.java @@ -0,0 +1,510 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE; +import static org.apache.iceberg.RowLevelOperationMode.MERGE_ON_READ; +import static org.assertj.core.api.Assertions.assertThat; + +import com.github.benmanes.caffeine.cache.Cache; +import java.io.File; +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.ContentScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Files; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.io.SeekableInputStream; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkExecutorCache.CacheValue; +import org.apache.iceberg.spark.SparkExecutorCache.Conf; +import org.apache.iceberg.util.CharSequenceSet; +import org.apache.iceberg.util.Pair; +import org.apache.spark.SparkEnv; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.storage.memory.MemoryStore; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkExecutorCache extends TestBaseWithCatalog { + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + protected static Object[][] parameters() { + return new Object[][] { + { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", + "hive", + CatalogProperties.FILE_IO_IMPL, + CustomFileIO.class.getName(), + "default-namespace", + "default") + }, + }; + } + + private static final String UPDATES_VIEW_NAME = "updates"; + private static final AtomicInteger JOB_COUNTER = new AtomicInteger(); + private static final Map INPUT_FILES = + Collections.synchronizedMap(Maps.newHashMap()); + + private String targetTableName; + private TableIdentifier targetTableIdent; + + @BeforeEach + public void configureTargetTableName() 
{ + String name = "target_exec_cache_" + JOB_COUNTER.incrementAndGet(); + this.targetTableName = tableName(name); + this.targetTableIdent = TableIdentifier.of(Namespace.of("default"), name); + } + + @AfterEach + public void releaseResources() { + sql("DROP TABLE IF EXISTS %s", targetTableName); + sql("DROP TABLE IF EXISTS %s", UPDATES_VIEW_NAME); + INPUT_FILES.clear(); + } + + @TestTemplate + public void testCacheValueWeightOverflow() { + CacheValue cacheValue = new CacheValue("v", Integer.MAX_VALUE + 1L); + assertThat(cacheValue.weight()).isEqualTo(Integer.MAX_VALUE); + } + + @TestTemplate + public void testCacheEnabledConfig() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.EXECUTOR_CACHE_ENABLED, "true"), + () -> { + Conf conf = new Conf(); + assertThat(conf.cacheEnabled()).isTrue(); + }); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.EXECUTOR_CACHE_ENABLED, "false"), + () -> { + Conf conf = new Conf(); + assertThat(conf.cacheEnabled()).isFalse(); + }); + } + + @TestTemplate + public void testTimeoutConfig() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.EXECUTOR_CACHE_TIMEOUT, "10s"), + () -> { + Conf conf = new Conf(); + assertThat(conf.timeout()).hasSeconds(10); + }); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.EXECUTOR_CACHE_TIMEOUT, "2m"), + () -> { + Conf conf = new Conf(); + assertThat(conf.timeout()).hasMinutes(2); + }); + } + + @TestTemplate + public void testMaxEntrySizeConfig() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.EXECUTOR_CACHE_MAX_ENTRY_SIZE, "128"), + () -> { + Conf conf = new Conf(); + assertThat(conf.maxEntrySize()).isEqualTo(128L); + }); + } + + @TestTemplate + public void testMaxTotalSizeConfig() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.EXECUTOR_CACHE_MAX_TOTAL_SIZE, "512"), + () -> { + Conf conf = new Conf(); + assertThat(conf.maxTotalSize()).isEqualTo(512L); + }); + } + + @TestTemplate + public void testConcurrentAccess() throws InterruptedException { + SparkExecutorCache cache = SparkExecutorCache.getOrCreate(); + + String table1 = "table1"; + String table2 = "table2"; + + Set loadedInternalKeys = Sets.newHashSet(); + + String key1 = "key1"; + String key2 = "key2"; + + long valueSize = 100L; + + int threadCount = 10; + ExecutorService executorService = Executors.newFixedThreadPool(threadCount); + + for (int threadNumber = 0; threadNumber < threadCount; threadNumber++) { + String group = threadNumber % 2 == 0 ? 
table1 : table2; + executorService.submit( + () -> { + for (int batch = 0; batch < 3; batch++) { + cache.getOrLoad( + group, + key1, + () -> { + String internalKey = toInternalKey(group, key1); + synchronized (loadedInternalKeys) { + // verify only one load was done for this key + assertThat(loadedInternalKeys.contains(internalKey)).isFalse(); + loadedInternalKeys.add(internalKey); + } + return "value1"; + }, + valueSize); + + cache.getOrLoad( + group, + key2, + () -> { + String internalKey = toInternalKey(group, key2); + synchronized (loadedInternalKeys) { + // verify only one load was done for this key + assertThat(loadedInternalKeys.contains(internalKey)).isFalse(); + loadedInternalKeys.add(internalKey); + } + return "value2"; + }, + valueSize); + } + }); + } + + executorService.shutdown(); + assertThat(executorService.awaitTermination(1, TimeUnit.MINUTES)).isTrue(); + + cache.invalidate(table1); + cache.invalidate(table2); + + // all keys must be invalidated + Cache state = fetchInternalCacheState(); + Set liveKeys = state.asMap().keySet(); + assertThat(liveKeys).noneMatch(key -> key.startsWith(table1) || key.startsWith(table2)); + } + + @TestTemplate + public void testCopyOnWriteDelete() throws Exception { + checkDelete(COPY_ON_WRITE); + } + + @TestTemplate + public void testMergeOnReadDelete() throws Exception { + checkDelete(MERGE_ON_READ); + } + + private void checkDelete(RowLevelOperationMode mode) throws Exception { + List deleteFiles = createAndInitTable(TableProperties.DELETE_MODE, mode); + + sql("DELETE FROM %s WHERE id = 1 OR id = 4", targetTableName); + + // there are 2 data files and 2 delete files that apply to both of them + // in CoW, the target table will be scanned 2 times (main query + runtime filter) + // the runtime filter may invalidate the cache so check at least some requests were hits + // in MoR, the target table will be scanned only once + // so each delete file must be opened once + int maxRequestCount = mode == COPY_ON_WRITE ? 3 : 1; + assertThat(deleteFiles).allMatch(deleteFile -> streamCount(deleteFile) <= maxRequestCount); + + // verify the final set of records is correct + assertEquals( + "Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id ASC", targetTableName)); + } + + @TestTemplate + public void testCopyOnWriteUpdate() throws Exception { + checkUpdate(COPY_ON_WRITE); + } + + @TestTemplate + public void testMergeOnReadUpdate() throws Exception { + checkUpdate(MERGE_ON_READ); + } + + private void checkUpdate(RowLevelOperationMode mode) throws Exception { + List deleteFiles = createAndInitTable(TableProperties.UPDATE_MODE, mode); + + Dataset updateDS = spark.createDataset(ImmutableList.of(1, 4), Encoders.INT()); + updateDS.createOrReplaceTempView(UPDATES_VIEW_NAME); + + sql("UPDATE %s SET id = -1 WHERE id IN (SELECT * FROM %s)", targetTableName, UPDATES_VIEW_NAME); + + // there are 2 data files and 2 delete files that apply to both of them + // in CoW, the target table will be scanned 3 times (2 in main query + runtime filter) + // the runtime filter may invalidate the cache so check at least some requests were hits + // in MoR, the target table will be scanned only once + // so each delete file must be opened once + int maxRequestCount = mode == COPY_ON_WRITE ? 
5 : 1; + assertThat(deleteFiles).allMatch(deleteFile -> streamCount(deleteFile) <= maxRequestCount); + + // verify the final set of records is correct + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(-1, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC", targetTableName)); + } + + @TestTemplate + public void testCopyOnWriteMerge() throws Exception { + checkMerge(COPY_ON_WRITE); + } + + @TestTemplate + public void testMergeOnReadMerge() throws Exception { + checkMerge(MERGE_ON_READ); + } + + private void checkMerge(RowLevelOperationMode mode) throws Exception { + List deleteFiles = createAndInitTable(TableProperties.MERGE_MODE, mode); + + Dataset updateDS = spark.createDataset(ImmutableList.of(1, 4), Encoders.INT()); + updateDS.createOrReplaceTempView(UPDATES_VIEW_NAME); + + sql( + "MERGE INTO %s t USING %s s " + + "ON t.id == s.value " + + "WHEN MATCHED THEN " + + " UPDATE SET id = 100 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id, dep) VALUES (-1, 'unknown')", + targetTableName, UPDATES_VIEW_NAME); + + // there are 2 data files and 2 delete files that apply to both of them + // in CoW, the target table will be scanned 2 times (main query + runtime filter) + // the runtime filter may invalidate the cache so check at least some requests were hits + // in MoR, the target table will be scanned only once + // so each delete file must be opened once + int maxRequestCount = mode == COPY_ON_WRITE ? 3 : 1; + assertThat(deleteFiles).allMatch(deleteFile -> streamCount(deleteFile) <= maxRequestCount); + + // verify the final set of records is correct + assertEquals( + "Should have expected rows", + ImmutableList.of(row(100, "hr"), row(100, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC", targetTableName)); + } + + private int streamCount(DeleteFile deleteFile) { + CustomInputFile inputFile = INPUT_FILES.get(deleteFile.location()); + return inputFile.streamCount(); + } + + private List createAndInitTable(String operation, RowLevelOperationMode mode) + throws Exception { + sql( + "CREATE TABLE %s (id INT, dep STRING) " + + "USING iceberg " + + "TBLPROPERTIES ('%s' '%s', '%s' '%s', '%s' '%s')", + targetTableName, + TableProperties.WRITE_METADATA_LOCATION, + temp.toString().replaceFirst("file:", ""), + TableProperties.WRITE_DATA_LOCATION, + temp.toString().replaceFirst("file:", ""), + operation, + mode.modeName()); + + append(targetTableName, new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); + append(targetTableName, new Employee(3, "hr"), new Employee(4, "hr"), new Employee(5, "hr")); + + Table table = validationCatalog.loadTable(targetTableIdent); + + List> posDeletes = + dataFiles(table).stream() + .map(dataFile -> Pair.of(dataFile.path(), 0L)) + .collect(Collectors.toList()); + Pair posDeleteResult = writePosDeletes(table, posDeletes); + DeleteFile posDeleteFile = posDeleteResult.first(); + CharSequenceSet referencedDataFiles = posDeleteResult.second(); + + DeleteFile eqDeleteFile = writeEqDeletes(table, "id", 2, 5); + + table + .newRowDelta() + .validateFromSnapshot(table.currentSnapshot().snapshotId()) + .validateDataFilesExist(referencedDataFiles) + .addDeletes(posDeleteFile) + .addDeletes(eqDeleteFile) + .commit(); + + sql("REFRESH TABLE %s", targetTableName); + + // invalidate the memory store to destroy all currently live table broadcasts + SparkEnv sparkEnv = SparkEnv.get(); + MemoryStore memoryStore = sparkEnv.blockManager().memoryStore(); + memoryStore.clear(); + + return ImmutableList.of(posDeleteFile, eqDeleteFile); + } + + private 
DeleteFile writeEqDeletes(Table table, String col, Object... values) throws IOException { + Schema deleteSchema = table.schema().select(col); + + Record delete = GenericRecord.create(deleteSchema); + List deletes = Lists.newArrayList(); + for (Object value : values) { + deletes.add(delete.copy(col, value)); + } + + OutputFile out = Files.localOutput(new File(temp.toFile(), "eq-deletes-" + UUID.randomUUID())); + return FileHelpers.writeDeleteFile(table, out, null, deletes, deleteSchema); + } + + private Pair writePosDeletes( + Table table, List> deletes) throws IOException { + OutputFile out = Files.localOutput(new File(temp.toFile(), "pos-deletes-" + UUID.randomUUID())); + return FileHelpers.writeDeleteFile(table, out, null, deletes); + } + + private void append(String target, Employee... employees) throws NoSuchTableException { + List input = Arrays.asList(employees); + Dataset inputDF = spark.createDataFrame(input, Employee.class); + inputDF.coalesce(1).writeTo(target).append(); + } + + private Collection dataFiles(Table table) { + try (CloseableIterable tasks = table.newScan().planFiles()) { + return ImmutableList.copyOf(Iterables.transform(tasks, ContentScanTask::file)); + } catch (IOException e) { + throw new RuntimeIOException(e); + } + } + + @SuppressWarnings("unchecked") + private static Cache fetchInternalCacheState() { + try { + Field stateField = SparkExecutorCache.class.getDeclaredField("state"); + stateField.setAccessible(true); + SparkExecutorCache cache = SparkExecutorCache.get(); + return (Cache) stateField.get(cache); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static String toInternalKey(String group, String key) { + return group + "_" + key; + } + + public static class CustomFileIO implements FileIO { + + public CustomFileIO() {} + + @Override + public InputFile newInputFile(String path) { + return INPUT_FILES.computeIfAbsent(path, key -> new CustomInputFile(path)); + } + + @Override + public OutputFile newOutputFile(String path) { + return Files.localOutput(path); + } + + @Override + public void deleteFile(String path) { + File file = new File(path); + if (!file.delete()) { + throw new RuntimeIOException("Failed to delete file: " + path); + } + } + } + + public static class CustomInputFile implements InputFile { + private final InputFile delegate; + private final AtomicInteger streamCount; + + public CustomInputFile(String path) { + this.delegate = Files.localInput(path); + this.streamCount = new AtomicInteger(); + } + + @Override + public long getLength() { + return delegate.getLength(); + } + + @Override + public SeekableInputStream newStream() { + streamCount.incrementAndGet(); + return delegate.newStream(); + } + + public int streamCount() { + return streamCount.get(); + } + + @Override + public String location() { + return delegate.location(); + } + + @Override + public boolean exists() { + return delegate.exists(); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkFilters.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkFilters.java new file mode 100644 index 000000000000..a6205ae9ea3f --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkFilters.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.sql.Date; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.ZoneId; +import java.time.temporal.ChronoUnit; +import java.util.Map; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualNullSafe; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.In; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.junit.jupiter.api.Test; + +public class TestSparkFilters { + + @Test + public void testQuotedAttributes() { + Map attrMap = Maps.newHashMap(); + attrMap.put("id", "id"); + attrMap.put("`i.d`", "i.d"); + attrMap.put("`i``d`", "i`d"); + attrMap.put("`d`.b.`dd```", "d.b.dd`"); + attrMap.put("a.`aa```.c", "a.aa`.c"); + + attrMap.forEach( + (quoted, unquoted) -> { + IsNull isNull = IsNull.apply(quoted); + Expression expectedIsNull = Expressions.isNull(unquoted); + Expression actualIsNull = SparkFilters.convert(isNull); + assertThat(actualIsNull.toString()) + .as("IsNull must match") + .isEqualTo(expectedIsNull.toString()); + + IsNotNull isNotNull = IsNotNull.apply(quoted); + Expression expectedIsNotNull = Expressions.notNull(unquoted); + Expression actualIsNotNull = SparkFilters.convert(isNotNull); + assertThat(actualIsNotNull.toString()) + .as("IsNotNull must match") + .isEqualTo(expectedIsNotNull.toString()); + + LessThan lt = LessThan.apply(quoted, 1); + Expression expectedLt = Expressions.lessThan(unquoted, 1); + Expression actualLt = SparkFilters.convert(lt); + assertThat(actualLt.toString()) + .as("LessThan must match") + .isEqualTo(expectedLt.toString()); + + LessThanOrEqual ltEq = LessThanOrEqual.apply(quoted, 1); + Expression expectedLtEq = Expressions.lessThanOrEqual(unquoted, 1); + Expression actualLtEq = SparkFilters.convert(ltEq); + assertThat(actualLtEq.toString()) + .as("LessThanOrEqual must match") + .isEqualTo(expectedLtEq.toString()); + + GreaterThan gt = GreaterThan.apply(quoted, 1); + Expression expectedGt = Expressions.greaterThan(unquoted, 1); + Expression actualGt = SparkFilters.convert(gt); + assertThat(actualGt.toString()) + .as("GreaterThan must match") + .isEqualTo(expectedGt.toString()); + + GreaterThanOrEqual gtEq = GreaterThanOrEqual.apply(quoted, 1); + Expression expectedGtEq = Expressions.greaterThanOrEqual(unquoted, 1); + Expression actualGtEq = 
SparkFilters.convert(gtEq); + assertThat(actualGtEq.toString()) + .as("GreaterThanOrEqual must match") + .isEqualTo(expectedGtEq.toString()); + + EqualTo eq = EqualTo.apply(quoted, 1); + Expression expectedEq = Expressions.equal(unquoted, 1); + Expression actualEq = SparkFilters.convert(eq); + assertThat(actualEq.toString()).as("EqualTo must match").isEqualTo(expectedEq.toString()); + + EqualNullSafe eqNullSafe = EqualNullSafe.apply(quoted, 1); + Expression expectedEqNullSafe = Expressions.equal(unquoted, 1); + Expression actualEqNullSafe = SparkFilters.convert(eqNullSafe); + assertThat(actualEqNullSafe.toString()) + .as("EqualNullSafe must match") + .isEqualTo(expectedEqNullSafe.toString()); + + In in = In.apply(quoted, new Integer[] {1}); + Expression expectedIn = Expressions.in(unquoted, 1); + Expression actualIn = SparkFilters.convert(in); + assertThat(actualIn.toString()).as("In must match").isEqualTo(expectedIn.toString()); + }); + } + + @Test + public void testTimestampFilterConversion() { + Instant instant = Instant.parse("2018-10-18T00:00:57.907Z"); + Timestamp timestamp = Timestamp.from(instant); + long epochMicros = ChronoUnit.MICROS.between(Instant.EPOCH, instant); + + Expression instantExpression = SparkFilters.convert(GreaterThan.apply("x", instant)); + Expression timestampExpression = SparkFilters.convert(GreaterThan.apply("x", timestamp)); + Expression rawExpression = Expressions.greaterThan("x", epochMicros); + + assertThat(timestampExpression.toString()) + .as("Generated Timestamp expression should be correct") + .isEqualTo(rawExpression.toString()); + + assertThat(instantExpression.toString()) + .as("Generated Instant expression should be correct") + .isEqualTo(rawExpression.toString()); + } + + @Test + public void testLocalDateTimeFilterConversion() { + LocalDateTime ldt = LocalDateTime.parse("2018-10-18T00:00:57"); + long epochMicros = + ChronoUnit.MICROS.between(LocalDateTime.ofInstant(Instant.EPOCH, ZoneId.of("UTC")), ldt); + + Expression instantExpression = SparkFilters.convert(GreaterThan.apply("x", ldt)); + Expression rawExpression = Expressions.greaterThan("x", epochMicros); + + assertThat(instantExpression.toString()) + .as("Generated Instant expression should be correct") + .isEqualTo(rawExpression.toString()); + } + + @Test + public void testDateFilterConversion() { + LocalDate localDate = LocalDate.parse("2018-10-18"); + Date date = Date.valueOf(localDate); + long epochDay = localDate.toEpochDay(); + + Expression localDateExpression = SparkFilters.convert(GreaterThan.apply("x", localDate)); + Expression dateExpression = SparkFilters.convert(GreaterThan.apply("x", date)); + Expression rawExpression = Expressions.greaterThan("x", epochDay); + + assertThat(localDateExpression.toString()) + .as("Generated localdate expression should be correct") + .isEqualTo(rawExpression.toString()); + + assertThat(dateExpression.toString()) + .as("Generated date expression should be correct") + .isEqualTo(rawExpression.toString()); + } + + @Test + public void testNestedInInsideNot() { + Not filter = + Not.apply(And.apply(EqualTo.apply("col1", 1), In.apply("col2", new Integer[] {1, 2}))); + Expression converted = SparkFilters.convert(filter); + assertThat(converted).as("Expression should not be converted").isNull(); + } + + @Test + public void testNotIn() { + Not filter = Not.apply(In.apply("col", new Integer[] {1, 2})); + Expression actual = SparkFilters.convert(filter); + Expression expected = + Expressions.and(Expressions.notNull("col"), Expressions.notIn("col", 1, 2)); + 
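Taken together, these cases pin down the V1 filter conversion contract: a source filter either maps to an equivalent Iceberg expression, or SparkFilters.convert(...) returns null so the caller can leave that predicate for Spark to evaluate. A minimal sketch of both outcomes (column names and values here are illustrative only):

    // convertible: equivalent to Expressions.equal("id", 1)
    Expression converted = SparkFilters.convert(EqualTo.apply("id", 1));

    // not convertible (IN nested inside NOT, as exercised above): convert(...) returns null
    Expression unsupported =
        SparkFilters.convert(
            Not.apply(And.apply(EqualTo.apply("col1", 1), In.apply("col2", new Integer[] {1, 2}))));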
assertThat(actual.toString()).as("Expressions should match").isEqualTo(expected.toString()); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkSchemaUtil.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkSchemaUtil.java new file mode 100644 index 000000000000..4d4091bf9a9a --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkSchemaUtil.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.List; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.expressions.AttributeReference; +import org.apache.spark.sql.catalyst.expressions.MetadataAttribute; +import org.apache.spark.sql.catalyst.types.DataTypeUtils; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.Test; + +public class TestSparkSchemaUtil { + private static final Schema TEST_SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + + private static final Schema TEST_SCHEMA_WITH_METADATA_COLS = + new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()), + MetadataColumns.FILE_PATH, + MetadataColumns.ROW_POSITION); + + @Test + public void testEstimateSizeMaxValue() throws IOException { + assertThat(SparkSchemaUtil.estimateSize(null, Long.MAX_VALUE)) + .as("estimateSize returns Long max value") + .isEqualTo(Long.MAX_VALUE); + } + + @Test + public void testEstimateSizeWithOverflow() throws IOException { + long tableSize = + SparkSchemaUtil.estimateSize(SparkSchemaUtil.convert(TEST_SCHEMA), Long.MAX_VALUE - 1); + assertThat(tableSize).as("estimateSize handles overflow").isEqualTo(Long.MAX_VALUE); + } + + @Test + public void testEstimateSize() throws IOException { + long tableSize = SparkSchemaUtil.estimateSize(SparkSchemaUtil.convert(TEST_SCHEMA), 1); + assertThat(tableSize).as("estimateSize matches with expected approximation").isEqualTo(24); + } + + @Test + public void testSchemaConversionWithMetaDataColumnSchema() { + StructType structType = SparkSchemaUtil.convert(TEST_SCHEMA_WITH_METADATA_COLS); + List attrRefs = + scala.collection.JavaConverters.seqAsJavaList(DataTypeUtils.toAttributes(structType)); + for (AttributeReference attrRef : attrRefs) { + if (MetadataColumns.isMetadataColumn(attrRef.name())) { + assertThat(MetadataAttribute.unapply(attrRef).isDefined()) + .as("metadata columns should have __metadata_col in attribute 
metadata") + .isTrue(); + } else { + assertThat(MetadataAttribute.unapply(attrRef).isDefined()) + .as("non metadata columns should not have __metadata_col in attribute metadata") + .isFalse(); + } + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkSessionCatalog.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkSessionCatalog.java new file mode 100644 index 000000000000..b8062a4a49fe --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkSessionCatalog.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class TestSparkSessionCatalog extends TestBase { + private final String envHmsUriKey = "spark.hadoop." 
+ METASTOREURIS.varname; + private final String catalogHmsUriKey = "spark.sql.catalog.spark_catalog.uri"; + private final String hmsUri = hiveConf.get(METASTOREURIS.varname); + + @BeforeAll + public static void setUpCatalog() { + spark + .conf() + .set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog"); + spark.conf().set("spark.sql.catalog.spark_catalog.type", "hive"); + } + + @BeforeEach + public void setupHmsUri() { + spark.sessionState().catalogManager().reset(); + spark.conf().set(envHmsUriKey, hmsUri); + spark.conf().set(catalogHmsUriKey, hmsUri); + } + + @Test + public void testValidateHmsUri() { + // HMS uris match + assertThat(spark.sessionState().catalogManager().v2SessionCatalog().defaultNamespace()[0]) + .isEqualTo("default"); + + // HMS uris doesn't match + spark.sessionState().catalogManager().reset(); + String catalogHmsUri = "RandomString"; + spark.conf().set(envHmsUriKey, hmsUri); + spark.conf().set(catalogHmsUriKey, catalogHmsUri); + + assertThatThrownBy(() -> spark.sessionState().catalogManager().v2SessionCatalog()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + String.format( + "Inconsistent Hive metastore URIs: %s (Spark session) != %s (spark_catalog)", + hmsUri, catalogHmsUri)); + + // no env HMS uri, only catalog HMS uri + spark.sessionState().catalogManager().reset(); + spark.conf().set(catalogHmsUriKey, hmsUri); + spark.conf().unset(envHmsUriKey); + assertThat(spark.sessionState().catalogManager().v2SessionCatalog().defaultNamespace()[0]) + .isEqualTo("default"); + + // no catalog HMS uri, only env HMS uri + spark.sessionState().catalogManager().reset(); + spark.conf().set(envHmsUriKey, hmsUri); + spark.conf().unset(catalogHmsUriKey); + assertThat(spark.sessionState().catalogManager().v2SessionCatalog().defaultNamespace()[0]) + .isEqualTo("default"); + } + + @Test + public void testLoadFunction() { + String functionClass = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper"; + + // load permanent UDF in Hive via FunctionCatalog + spark.sql(String.format("CREATE FUNCTION perm_upper AS '%s'", functionClass)); + assertThat(scalarSql("SELECT perm_upper('xyz')")) + .as("Load permanent UDF in Hive") + .isEqualTo("XYZ"); + + // load temporary UDF in Hive via FunctionCatalog + spark.sql(String.format("CREATE TEMPORARY FUNCTION temp_upper AS '%s'", functionClass)); + assertThat(scalarSql("SELECT temp_upper('xyz')")) + .as("Load temporary UDF in Hive") + .isEqualTo("XYZ"); + + // TODO: fix loading Iceberg built-in functions in SessionCatalog + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkTableUtil.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkTableUtil.java new file mode 100644 index 000000000000..772ae3a224ac --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkTableUtil.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.KryoHelpers; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.MetricsModes; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkTableUtil.SparkPartition; +import org.junit.jupiter.api.Test; + +public class TestSparkTableUtil { + @Test + public void testSparkPartitionOKryoSerialization() throws IOException { + Map values = ImmutableMap.of("id", "2"); + String uri = "s3://bucket/table/data/id=2"; + String format = "parquet"; + SparkPartition sparkPartition = new SparkPartition(values, uri, format); + + SparkPartition deserialized = KryoHelpers.roundTripSerialize(sparkPartition); + assertThat(sparkPartition).isEqualTo(deserialized); + } + + @Test + public void testSparkPartitionJavaSerialization() throws IOException, ClassNotFoundException { + Map values = ImmutableMap.of("id", "2"); + String uri = "s3://bucket/table/data/id=2"; + String format = "parquet"; + SparkPartition sparkPartition = new SparkPartition(values, uri, format); + + SparkPartition deserialized = TestHelpers.roundTripSerialize(sparkPartition); + assertThat(sparkPartition).isEqualTo(deserialized); + } + + @Test + public void testMetricsConfigKryoSerialization() throws Exception { + Map metricsConfig = + ImmutableMap.of( + TableProperties.DEFAULT_WRITE_METRICS_MODE, + "counts", + TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "col1", + "full", + TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "col2", + "truncate(16)"); + + MetricsConfig config = MetricsConfig.fromProperties(metricsConfig); + MetricsConfig deserialized = KryoHelpers.roundTripSerialize(config); + + assertThat(deserialized.columnMode("col1").toString()) + .isEqualTo(MetricsModes.Full.get().toString()); + assertThat(deserialized.columnMode("col2").toString()) + .isEqualTo(MetricsModes.Truncate.withLength(16).toString()); + assertThat(deserialized.columnMode("col3").toString()) + .isEqualTo(MetricsModes.Counts.get().toString()); + } + + @Test + public void testMetricsConfigJavaSerialization() throws Exception { + Map metricsConfig = + ImmutableMap.of( + TableProperties.DEFAULT_WRITE_METRICS_MODE, + "counts", + TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "col1", + "full", + TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "col2", + "truncate(16)"); + + MetricsConfig config = MetricsConfig.fromProperties(metricsConfig); + MetricsConfig deserialized = TestHelpers.roundTripSerialize(config); + + assertThat(deserialized.columnMode("col1").toString()) + .isEqualTo(MetricsModes.Full.get().toString()); + assertThat(deserialized.columnMode("col2").toString()) + .isEqualTo(MetricsModes.Truncate.withLength(16).toString()); + assertThat(deserialized.columnMode("col3").toString()) + .isEqualTo(MetricsModes.Counts.get().toString()); + } +} diff --git 
a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java new file mode 100644 index 000000000000..44fb64120ca0 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java @@ -0,0 +1,834 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.time.Instant; +import java.time.LocalDate; +import java.time.temporal.ChronoUnit; +import java.util.Map; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.ExpressionUtil; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.expressions.UnboundTerm; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.functions.BucketFunction; +import org.apache.iceberg.spark.functions.DaysFunction; +import org.apache.iceberg.spark.functions.HoursFunction; +import org.apache.iceberg.spark.functions.IcebergVersionFunction; +import org.apache.iceberg.spark.functions.MonthsFunction; +import org.apache.iceberg.spark.functions.TruncateFunction; +import org.apache.iceberg.spark.functions.YearsFunction; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.LiteralValue; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc; +import org.apache.spark.sql.connector.expressions.filter.And; +import org.apache.spark.sql.connector.expressions.filter.Not; +import org.apache.spark.sql.connector.expressions.filter.Or; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.jupiter.api.Test; + +public class TestSparkV2Filters { + + private static final Types.StructType STRUCT = + Types.StructType.of( + Types.NestedField.optional(1, "dateCol", Types.DateType.get()), + Types.NestedField.optional(2, "tsCol", Types.TimestampType.withZone()), + Types.NestedField.optional(3, "tsNtzCol", Types.TimestampType.withoutZone()), + Types.NestedField.optional(4, "intCol", Types.IntegerType.get()), + Types.NestedField.optional(5, "strCol", Types.StringType.get())); + + 
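For orientation (an illustrative sketch, not code from this patch): the tests in this file exercise SparkV2Filters.convert, which translates Spark DSv2 Predicate trees into Iceberg Expressions and returns null when a predicate cannot be translated. A minimal sketch of the round trip the assertions below rely on, using the same imports as this test file; "intCol" matches the STRUCT defined above:

    // Build the Spark-side predicate intCol = 1 ...
    NamedReference ref = FieldReference.apply("intCol");
    LiteralValue<Integer> one = new LiteralValue<>(1, DataTypes.IntegerType);
    Predicate sparkPredicate =
        new Predicate(
            "=", new org.apache.spark.sql.connector.expressions.Expression[] {ref, one});

    // ... and convert it; the result is equivalent to Expressions.equal("intCol", 1).
    // convert(...) yields null for unsupported predicates, which the "invalid" cases below assert.
    Expression converted = SparkV2Filters.convert(sparkPredicate);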
@SuppressWarnings("checkstyle:MethodLength") + @Test + public void testV2Filters() { + Map attrMap = Maps.newHashMap(); + attrMap.put("id", "id"); + attrMap.put("`i.d`", "i.d"); + attrMap.put("`i``d`", "i`d"); + attrMap.put("`d`.b.`dd```", "d.b.dd`"); + attrMap.put("a.`aa```.c", "a.aa`.c"); + + attrMap.forEach( + (quoted, unquoted) -> { + NamedReference namedReference = FieldReference.apply(quoted); + org.apache.spark.sql.connector.expressions.Expression[] attrOnly = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference}; + + LiteralValue value = new LiteralValue(1, DataTypes.IntegerType); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, value}; + org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr = + new org.apache.spark.sql.connector.expressions.Expression[] {value, namedReference}; + + Predicate isNull = new Predicate("IS_NULL", attrOnly); + Expression expectedIsNull = Expressions.isNull(unquoted); + Expression actualIsNull = SparkV2Filters.convert(isNull); + assertThat(actualIsNull.toString()) + .as("IsNull must match") + .isEqualTo(expectedIsNull.toString()); + + Predicate isNotNull = new Predicate("IS_NOT_NULL", attrOnly); + Expression expectedIsNotNull = Expressions.notNull(unquoted); + Expression actualIsNotNull = SparkV2Filters.convert(isNotNull); + assertThat(actualIsNotNull.toString()) + .as("IsNotNull must match") + .isEqualTo(expectedIsNotNull.toString()); + + Predicate lt1 = new Predicate("<", attrAndValue); + Expression expectedLt1 = Expressions.lessThan(unquoted, 1); + Expression actualLt1 = SparkV2Filters.convert(lt1); + assertThat(actualLt1.toString()) + .as("LessThan must match") + .isEqualTo(expectedLt1.toString()); + + Predicate lt2 = new Predicate("<", valueAndAttr); + Expression expectedLt2 = Expressions.greaterThan(unquoted, 1); + Expression actualLt2 = SparkV2Filters.convert(lt2); + assertThat(actualLt2.toString()) + .as("LessThan must match") + .isEqualTo(expectedLt2.toString()); + + Predicate ltEq1 = new Predicate("<=", attrAndValue); + Expression expectedLtEq1 = Expressions.lessThanOrEqual(unquoted, 1); + Expression actualLtEq1 = SparkV2Filters.convert(ltEq1); + assertThat(actualLtEq1.toString()) + .as("LessThanOrEqual must match") + .isEqualTo(expectedLtEq1.toString()); + + Predicate ltEq2 = new Predicate("<=", valueAndAttr); + Expression expectedLtEq2 = Expressions.greaterThanOrEqual(unquoted, 1); + Expression actualLtEq2 = SparkV2Filters.convert(ltEq2); + assertThat(actualLtEq2.toString()) + .as("LessThanOrEqual must match") + .isEqualTo(expectedLtEq2.toString()); + + Predicate gt1 = new Predicate(">", attrAndValue); + Expression expectedGt1 = Expressions.greaterThan(unquoted, 1); + Expression actualGt1 = SparkV2Filters.convert(gt1); + assertThat(actualGt1.toString()) + .as("GreaterThan must match") + .isEqualTo(expectedGt1.toString()); + + Predicate gt2 = new Predicate(">", valueAndAttr); + Expression expectedGt2 = Expressions.lessThan(unquoted, 1); + Expression actualGt2 = SparkV2Filters.convert(gt2); + assertThat(actualGt2.toString()) + .as("GreaterThan must match") + .isEqualTo(expectedGt2.toString()); + + Predicate gtEq1 = new Predicate(">=", attrAndValue); + Expression expectedGtEq1 = Expressions.greaterThanOrEqual(unquoted, 1); + Expression actualGtEq1 = SparkV2Filters.convert(gtEq1); + assertThat(actualGtEq1.toString()) + .as("GreaterThanOrEqual must match") + .isEqualTo(expectedGtEq1.toString()); + + 
Predicate gtEq2 = new Predicate(">=", valueAndAttr); + Expression expectedGtEq2 = Expressions.lessThanOrEqual(unquoted, 1); + Expression actualGtEq2 = SparkV2Filters.convert(gtEq2); + assertThat(actualGtEq2.toString()) + .as("GreaterThanOrEqual must match") + .isEqualTo(expectedGtEq2.toString()); + + Predicate eq1 = new Predicate("=", attrAndValue); + Expression expectedEq1 = Expressions.equal(unquoted, 1); + Expression actualEq1 = SparkV2Filters.convert(eq1); + assertThat(actualEq1.toString()) + .as("EqualTo must match") + .isEqualTo(expectedEq1.toString()); + + Predicate eq2 = new Predicate("=", valueAndAttr); + Expression expectedEq2 = Expressions.equal(unquoted, 1); + Expression actualEq2 = SparkV2Filters.convert(eq2); + assertThat(actualEq2.toString()) + .as("EqualTo must match") + .isEqualTo(expectedEq2.toString()); + + Predicate notEq1 = new Predicate("<>", attrAndValue); + Expression expectedNotEq1 = Expressions.notEqual(unquoted, 1); + Expression actualNotEq1 = SparkV2Filters.convert(notEq1); + assertThat(actualNotEq1.toString()) + .as("NotEqualTo must match") + .isEqualTo(expectedNotEq1.toString()); + + Predicate notEq2 = new Predicate("<>", valueAndAttr); + Expression expectedNotEq2 = Expressions.notEqual(unquoted, 1); + Expression actualNotEq2 = SparkV2Filters.convert(notEq2); + assertThat(actualNotEq2.toString()) + .as("NotEqualTo must match") + .isEqualTo(expectedNotEq2.toString()); + + Predicate eqNullSafe1 = new Predicate("<=>", attrAndValue); + Expression expectedEqNullSafe1 = Expressions.equal(unquoted, 1); + Expression actualEqNullSafe1 = SparkV2Filters.convert(eqNullSafe1); + assertThat(actualEqNullSafe1.toString()) + .as("EqualNullSafe must match") + .isEqualTo(expectedEqNullSafe1.toString()); + + Predicate eqNullSafe2 = new Predicate("<=>", valueAndAttr); + Expression expectedEqNullSafe2 = Expressions.equal(unquoted, 1); + Expression actualEqNullSafe2 = SparkV2Filters.convert(eqNullSafe2); + assertThat(actualEqNullSafe2.toString()) + .as("EqualNullSafe must match") + .isEqualTo(expectedEqNullSafe2.toString()); + + LiteralValue str = + new LiteralValue(UTF8String.fromString("iceberg"), DataTypes.StringType); + org.apache.spark.sql.connector.expressions.Expression[] attrAndStr = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, str}; + Predicate startsWith = new Predicate("STARTS_WITH", attrAndStr); + Expression expectedStartsWith = Expressions.startsWith(unquoted, "iceberg"); + Expression actualStartsWith = SparkV2Filters.convert(startsWith); + assertThat(actualStartsWith.toString()) + .as("StartsWith must match") + .isEqualTo(expectedStartsWith.toString()); + + Predicate in = new Predicate("IN", attrAndValue); + Expression expectedIn = Expressions.in(unquoted, 1); + Expression actualIn = SparkV2Filters.convert(in); + assertThat(actualIn.toString()).as("In must match").isEqualTo(expectedIn.toString()); + + Predicate and = new And(lt1, eq1); + Expression expectedAnd = Expressions.and(expectedLt1, expectedEq1); + Expression actualAnd = SparkV2Filters.convert(and); + assertThat(actualAnd.toString()).as("And must match").isEqualTo(expectedAnd.toString()); + + org.apache.spark.sql.connector.expressions.Expression[] attrAndAttr = + new org.apache.spark.sql.connector.expressions.Expression[] { + namedReference, namedReference + }; + Predicate invalid = new Predicate("<", attrAndAttr); + Predicate andWithInvalidLeft = new And(invalid, eq1); + Expression convertedAnd = SparkV2Filters.convert(andWithInvalidLeft); + assertThat(convertedAnd).as("And 
must match").isNull(); + + Predicate or = new Or(lt1, eq1); + Expression expectedOr = Expressions.or(expectedLt1, expectedEq1); + Expression actualOr = SparkV2Filters.convert(or); + assertThat(actualOr.toString()).as("Or must match").isEqualTo(expectedOr.toString()); + + Predicate orWithInvalidLeft = new Or(invalid, eq1); + Expression convertedOr = SparkV2Filters.convert(orWithInvalidLeft); + assertThat(convertedOr).as("Or must match").isNull(); + + Predicate not = new Not(lt1); + Expression expectedNot = Expressions.not(expectedLt1); + Expression actualNot = SparkV2Filters.convert(not); + assertThat(actualNot.toString()).as("Not must match").isEqualTo(expectedNot.toString()); + }); + } + + @Test + public void testEqualToNull() { + String col = "col"; + NamedReference namedReference = FieldReference.apply(col); + LiteralValue value = new LiteralValue(null, DataTypes.IntegerType); + + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, value}; + org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr = + new org.apache.spark.sql.connector.expressions.Expression[] {value, namedReference}; + + Predicate eq1 = new Predicate("=", attrAndValue); + assertThatThrownBy(() -> SparkV2Filters.convert(eq1)) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("Expression is always false"); + + Predicate eq2 = new Predicate("=", valueAndAttr); + assertThatThrownBy(() -> SparkV2Filters.convert(eq2)) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("Expression is always false"); + + Predicate eqNullSafe1 = new Predicate("<=>", attrAndValue); + Expression expectedEqNullSafe = Expressions.isNull(col); + Expression actualEqNullSafe1 = SparkV2Filters.convert(eqNullSafe1); + assertThat(actualEqNullSafe1.toString()).isEqualTo(expectedEqNullSafe.toString()); + + Predicate eqNullSafe2 = new Predicate("<=>", valueAndAttr); + Expression actualEqNullSafe2 = SparkV2Filters.convert(eqNullSafe2); + assertThat(actualEqNullSafe2.toString()).isEqualTo(expectedEqNullSafe.toString()); + } + + @Test + public void testEqualToNaN() { + String col = "col"; + NamedReference namedReference = FieldReference.apply(col); + LiteralValue value = new LiteralValue(Float.NaN, DataTypes.FloatType); + + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, value}; + org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr = + new org.apache.spark.sql.connector.expressions.Expression[] {value, namedReference}; + + Predicate eqNaN1 = new Predicate("=", attrAndValue); + Expression expectedEqNaN = Expressions.isNaN(col); + Expression actualEqNaN1 = SparkV2Filters.convert(eqNaN1); + assertThat(actualEqNaN1.toString()).isEqualTo(expectedEqNaN.toString()); + + Predicate eqNaN2 = new Predicate("=", valueAndAttr); + Expression actualEqNaN2 = SparkV2Filters.convert(eqNaN2); + assertThat(actualEqNaN2.toString()).isEqualTo(expectedEqNaN.toString()); + } + + @Test + public void testNotEqualToNull() { + String col = "col"; + NamedReference namedReference = FieldReference.apply(col); + LiteralValue value = new LiteralValue(null, DataTypes.IntegerType); + + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, value}; + org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr = + new 
org.apache.spark.sql.connector.expressions.Expression[] {value, namedReference}; + + Predicate notEq1 = new Predicate("<>", attrAndValue); + assertThatThrownBy(() -> SparkV2Filters.convert(notEq1)) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("Expression is always false"); + + Predicate notEq2 = new Predicate("<>", valueAndAttr); + assertThatThrownBy(() -> SparkV2Filters.convert(notEq2)) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("Expression is always false"); + } + + @Test + public void testNotEqualToNaN() { + String col = "col"; + NamedReference namedReference = FieldReference.apply(col); + LiteralValue value = new LiteralValue(Float.NaN, DataTypes.FloatType); + + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, value}; + org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr = + new org.apache.spark.sql.connector.expressions.Expression[] {value, namedReference}; + + Predicate notEqNaN1 = new Predicate("<>", attrAndValue); + Expression expectedNotEqNaN = Expressions.notNaN(col); + Expression actualNotEqNaN1 = SparkV2Filters.convert(notEqNaN1); + assertThat(actualNotEqNaN1.toString()).isEqualTo(expectedNotEqNaN.toString()); + + Predicate notEqNaN2 = new Predicate("<>", valueAndAttr); + Expression actualNotEqNaN2 = SparkV2Filters.convert(notEqNaN2); + assertThat(actualNotEqNaN2.toString()).isEqualTo(expectedNotEqNaN.toString()); + } + + @Test + public void testInValuesContainNull() { + String col = "strCol"; + NamedReference namedReference = FieldReference.apply(col); + LiteralValue nullValue = new LiteralValue(null, DataTypes.StringType); + LiteralValue value1 = new LiteralValue("value1", DataTypes.StringType); + LiteralValue value2 = new LiteralValue("value2", DataTypes.StringType); + + // Values only contains null + Predicate inNull = new Predicate("IN", expressions(namedReference, nullValue)); + Expression expectedInNull = Expressions.in(col); + Expression actualInNull = SparkV2Filters.convert(inNull); + assertEquals(expectedInNull, actualInNull); + + Predicate in = new Predicate("IN", expressions(namedReference, nullValue, value1, value2)); + Expression expectedIn = Expressions.in(col, "value1", "value2"); + Expression actualIn = SparkV2Filters.convert(in); + assertEquals(expectedIn, actualIn); + } + + @Test + public void testNotInNull() { + String col = "strCol"; + NamedReference namedReference = FieldReference.apply(col); + LiteralValue nullValue = new LiteralValue(null, DataTypes.StringType); + LiteralValue value1 = new LiteralValue("value1", DataTypes.StringType); + LiteralValue value2 = new LiteralValue("value2", DataTypes.StringType); + + // Values only contains null + Predicate notInNull = new Not(new Predicate("IN", expressions(namedReference, nullValue))); + Expression expectedNotInNull = + Expressions.and(Expressions.notNull(col), Expressions.notIn(col)); + Expression actualNotInNull = SparkV2Filters.convert(notInNull); + assertEquals(expectedNotInNull, actualNotInNull); + + Predicate notIn = + new Not(new Predicate("IN", expressions(namedReference, nullValue, value1, value2))); + Expression expectedNotIn = + Expressions.and(Expressions.notNull(col), Expressions.notIn(col, "value1", "value2")); + Expression actualNotIn = SparkV2Filters.convert(notIn); + assertEquals(expectedNotIn, actualNotIn); + } + + @Test + public void testTimestampFilterConversion() { + Instant instant = 
Instant.parse("2018-10-18T00:00:57.907Z"); + long epochMicros = ChronoUnit.MICROS.between(Instant.EPOCH, instant); + + NamedReference namedReference = FieldReference.apply("x"); + LiteralValue ts = new LiteralValue(epochMicros, DataTypes.TimestampType); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, ts}; + + Predicate predicate = new Predicate(">", attrAndValue); + Expression tsExpression = SparkV2Filters.convert(predicate); + Expression rawExpression = Expressions.greaterThan("x", epochMicros); + + assertThat(tsExpression.toString()) + .as("Generated Timestamp expression should be correct") + .isEqualTo(rawExpression.toString()); + } + + @Test + public void testDateFilterConversion() { + LocalDate localDate = LocalDate.parse("2018-10-18"); + long epochDay = localDate.toEpochDay(); + + NamedReference namedReference = FieldReference.apply("x"); + LiteralValue ts = new LiteralValue(epochDay, DataTypes.DateType); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, ts}; + + Predicate predicate = new Predicate(">", attrAndValue); + Expression dateExpression = SparkV2Filters.convert(predicate); + Expression rawExpression = Expressions.greaterThan("x", epochDay); + + assertThat(dateExpression.toString()) + .as("Generated date expression should be correct") + .isEqualTo(rawExpression.toString()); + } + + @Test + public void testNestedInInsideNot() { + NamedReference namedReference1 = FieldReference.apply("col1"); + LiteralValue v1 = new LiteralValue(1, DataTypes.IntegerType); + LiteralValue v2 = new LiteralValue(2, DataTypes.IntegerType); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue1 = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference1, v1}; + Predicate equal = new Predicate("=", attrAndValue1); + + NamedReference namedReference2 = FieldReference.apply("col2"); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue2 = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference2, v1, v2}; + Predicate in = new Predicate("IN", attrAndValue2); + + Not filter = new Not(new And(equal, in)); + Expression converted = SparkV2Filters.convert(filter); + assertThat(converted).as("Expression should not be converted").isNull(); + } + + @Test + public void testNotIn() { + NamedReference namedReference = FieldReference.apply("col"); + LiteralValue v1 = new LiteralValue(1, DataTypes.IntegerType); + LiteralValue v2 = new LiteralValue(2, DataTypes.IntegerType); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, v1, v2}; + + Predicate in = new Predicate("IN", attrAndValue); + Not not = new Not(in); + + Expression actual = SparkV2Filters.convert(not); + Expression expected = + Expressions.and(Expressions.notNull("col"), Expressions.notIn("col", 1, 2)); + assertThat(actual.toString()).as("Expressions should match").isEqualTo(expected.toString()); + } + + @Test + public void testDateToYears() { + ScalarFunction dateToYearsFunc = new YearsFunction.DateToYearsFunction(); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + dateToYearsFunc.name(), + dateToYearsFunc.canonicalName(), + expressions(FieldReference.apply("dateCol"))); + testUDF(udf, Expressions.year("dateCol"), dateToYears("2023-06-25"), DataTypes.IntegerType); + } + 
+ @Test + public void testTsToYears() { + ScalarFunction tsToYearsFunc = new YearsFunction.TimestampToYearsFunction(); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + tsToYearsFunc.name(), + tsToYearsFunc.canonicalName(), + expressions(FieldReference.apply("tsCol"))); + testUDF( + udf, + Expressions.year("tsCol"), + timestampToYears("2023-12-03T10:15:30+01:00"), + DataTypes.IntegerType); + } + + @Test + public void testTsNtzToYears() { + ScalarFunction tsNtzToYearsFunc = new YearsFunction.TimestampNtzToYearsFunction(); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + tsNtzToYearsFunc.name(), + tsNtzToYearsFunc.canonicalName(), + expressions(FieldReference.apply("tsNtzCol"))); + testUDF( + udf, + Expressions.year("tsNtzCol"), + timestampNtzToYears("2023-06-25T13:15:30"), + DataTypes.IntegerType); + } + + @Test + public void testDateToMonths() { + ScalarFunction dateToMonthsFunc = new MonthsFunction.DateToMonthsFunction(); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + dateToMonthsFunc.name(), + dateToMonthsFunc.canonicalName(), + expressions(FieldReference.apply("dateCol"))); + testUDF(udf, Expressions.month("dateCol"), dateToMonths("2023-06-25"), DataTypes.IntegerType); + } + + @Test + public void testTsToMonths() { + ScalarFunction tsToMonthsFunc = new MonthsFunction.TimestampToMonthsFunction(); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + tsToMonthsFunc.name(), + tsToMonthsFunc.canonicalName(), + expressions(FieldReference.apply("tsCol"))); + testUDF( + udf, + Expressions.month("tsCol"), + timestampToMonths("2023-12-03T10:15:30+01:00"), + DataTypes.IntegerType); + } + + @Test + public void testTsNtzToMonths() { + ScalarFunction tsNtzToMonthsFunc = new MonthsFunction.TimestampNtzToMonthsFunction(); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + tsNtzToMonthsFunc.name(), + tsNtzToMonthsFunc.canonicalName(), + expressions(FieldReference.apply("tsNtzCol"))); + testUDF( + udf, + Expressions.month("tsNtzCol"), + timestampNtzToMonths("2023-12-03T10:15:30"), + DataTypes.IntegerType); + } + + @Test + public void testDateToDays() { + ScalarFunction dateToDayFunc = new DaysFunction.DateToDaysFunction(); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + dateToDayFunc.name(), + dateToDayFunc.canonicalName(), + expressions(FieldReference.apply("dateCol"))); + testUDF(udf, Expressions.day("dateCol"), dateToDays("2023-06-25"), DataTypes.IntegerType); + } + + @Test + public void testTsToDays() { + ScalarFunction tsToDaysFunc = new DaysFunction.TimestampToDaysFunction(); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + tsToDaysFunc.name(), + tsToDaysFunc.canonicalName(), + expressions(FieldReference.apply("tsCol"))); + testUDF( + udf, + Expressions.day("tsCol"), + timestampToDays("2023-12-03T10:15:30+01:00"), + DataTypes.IntegerType); + } + + @Test + public void testTsNtzToDays() { + ScalarFunction tsNtzToDaysFunc = new DaysFunction.TimestampNtzToDaysFunction(); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + tsNtzToDaysFunc.name(), + tsNtzToDaysFunc.canonicalName(), + expressions(FieldReference.apply("tsNtzCol"))); + testUDF( + udf, + Expressions.day("tsNtzCol"), + timestampNtzToDays("2023-12-03T10:15:30"), + DataTypes.IntegerType); + } + + @Test + public void testTsToHours() { + ScalarFunction tsToHourFunc = new HoursFunction.TimestampToHoursFunction(); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + tsToHourFunc.name(), + tsToHourFunc.canonicalName(), + 
expressions(FieldReference.apply("tsCol"))); + testUDF( + udf, + Expressions.hour("tsCol"), + timestampToHours("2023-12-03T10:15:30+01:00"), + DataTypes.IntegerType); + } + + @Test + public void testTsNtzToHours() { + ScalarFunction tsNtzToHourFunc = new HoursFunction.TimestampNtzToHoursFunction(); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + tsNtzToHourFunc.name(), + tsNtzToHourFunc.canonicalName(), + expressions(FieldReference.apply("tsNtzCol"))); + testUDF( + udf, + Expressions.hour("tsNtzCol"), + timestampNtzToHours("2023-12-03T10:15:30"), + DataTypes.IntegerType); + } + + @Test + public void testBucket() { + ScalarFunction bucketInt = new BucketFunction.BucketInt(DataTypes.IntegerType); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + bucketInt.name(), + bucketInt.canonicalName(), + expressions( + LiteralValue.apply(4, DataTypes.IntegerType), FieldReference.apply("intCol"))); + testUDF(udf, Expressions.bucket("intCol", 4), 2, DataTypes.IntegerType); + } + + @Test + public void testTruncate() { + ScalarFunction truncate = new TruncateFunction.TruncateString(); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + truncate.name(), + truncate.canonicalName(), + expressions( + LiteralValue.apply(6, DataTypes.IntegerType), FieldReference.apply("strCol"))); + testUDF(udf, Expressions.truncate("strCol", 6), "prefix", DataTypes.StringType); + } + + @Test + public void testUnsupportedUDFConvert() { + ScalarFunction icebergVersionFunc = + (ScalarFunction) new IcebergVersionFunction().bind(new StructType()); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + icebergVersionFunc.name(), + icebergVersionFunc.canonicalName(), + new org.apache.spark.sql.connector.expressions.Expression[] {}); + LiteralValue literalValue = new LiteralValue("1.3.0", DataTypes.StringType); + Predicate predicate = new Predicate("=", expressions(udf, literalValue)); + + Expression icebergExpr = SparkV2Filters.convert(predicate); + assertThat(icebergExpr).isNull(); + } + + private void testUDF( + org.apache.spark.sql.connector.expressions.Expression udf, + UnboundTerm expectedTerm, + T value, + DataType dataType) { + org.apache.spark.sql.connector.expressions.Expression[] attrOnly = expressions(udf); + + LiteralValue literalValue = new LiteralValue(value, dataType); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue = + expressions(udf, literalValue); + org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr = + expressions(literalValue, udf); + + Predicate isNull = new Predicate("IS_NULL", attrOnly); + Expression expectedIsNull = Expressions.isNull(expectedTerm); + Expression actualIsNull = SparkV2Filters.convert(isNull); + assertEquals(expectedIsNull, actualIsNull); + + Predicate isNotNull = new Predicate("IS_NOT_NULL", attrOnly); + Expression expectedIsNotNull = Expressions.notNull(expectedTerm); + Expression actualIsNotNull = SparkV2Filters.convert(isNotNull); + assertEquals(expectedIsNotNull, actualIsNotNull); + + Predicate lt1 = new Predicate("<", attrAndValue); + Expression expectedLt1 = Expressions.lessThan(expectedTerm, value); + Expression actualLt1 = SparkV2Filters.convert(lt1); + assertEquals(expectedLt1, actualLt1); + + Predicate lt2 = new Predicate("<", valueAndAttr); + Expression expectedLt2 = Expressions.greaterThan(expectedTerm, value); + Expression actualLt2 = SparkV2Filters.convert(lt2); + assertEquals(expectedLt2, actualLt2); + + Predicate ltEq1 = new Predicate("<=", attrAndValue); + Expression expectedLtEq1 = 
Expressions.lessThanOrEqual(expectedTerm, value); + Expression actualLtEq1 = SparkV2Filters.convert(ltEq1); + assertEquals(expectedLtEq1, actualLtEq1); + + Predicate ltEq2 = new Predicate("<=", valueAndAttr); + Expression expectedLtEq2 = Expressions.greaterThanOrEqual(expectedTerm, value); + Expression actualLtEq2 = SparkV2Filters.convert(ltEq2); + assertEquals(expectedLtEq2, actualLtEq2); + + Predicate gt1 = new Predicate(">", attrAndValue); + Expression expectedGt1 = Expressions.greaterThan(expectedTerm, value); + Expression actualGt1 = SparkV2Filters.convert(gt1); + assertEquals(expectedGt1, actualGt1); + + Predicate gt2 = new Predicate(">", valueAndAttr); + Expression expectedGt2 = Expressions.lessThan(expectedTerm, value); + Expression actualGt2 = SparkV2Filters.convert(gt2); + assertEquals(expectedGt2, actualGt2); + + Predicate gtEq1 = new Predicate(">=", attrAndValue); + Expression expectedGtEq1 = Expressions.greaterThanOrEqual(expectedTerm, value); + Expression actualGtEq1 = SparkV2Filters.convert(gtEq1); + assertEquals(expectedGtEq1, actualGtEq1); + + Predicate gtEq2 = new Predicate(">=", valueAndAttr); + Expression expectedGtEq2 = Expressions.lessThanOrEqual(expectedTerm, value); + Expression actualGtEq2 = SparkV2Filters.convert(gtEq2); + assertEquals(expectedGtEq2, actualGtEq2); + + Predicate eq1 = new Predicate("=", attrAndValue); + Expression expectedEq1 = Expressions.equal(expectedTerm, value); + Expression actualEq1 = SparkV2Filters.convert(eq1); + assertEquals(expectedEq1, actualEq1); + + Predicate eq2 = new Predicate("=", valueAndAttr); + Expression expectedEq2 = Expressions.equal(expectedTerm, value); + Expression actualEq2 = SparkV2Filters.convert(eq2); + assertEquals(expectedEq2, actualEq2); + + Predicate notEq1 = new Predicate("<>", attrAndValue); + Expression expectedNotEq1 = Expressions.notEqual(expectedTerm, value); + Expression actualNotEq1 = SparkV2Filters.convert(notEq1); + assertEquals(expectedNotEq1, actualNotEq1); + + Predicate notEq2 = new Predicate("<>", valueAndAttr); + Expression expectedNotEq2 = Expressions.notEqual(expectedTerm, value); + Expression actualNotEq2 = SparkV2Filters.convert(notEq2); + assertEquals(expectedNotEq2, actualNotEq2); + + Predicate eqNullSafe1 = new Predicate("<=>", attrAndValue); + Expression expectedEqNullSafe1 = Expressions.equal(expectedTerm, value); + Expression actualEqNullSafe1 = SparkV2Filters.convert(eqNullSafe1); + assertEquals(expectedEqNullSafe1, actualEqNullSafe1); + + Predicate eqNullSafe2 = new Predicate("<=>", valueAndAttr); + Expression expectedEqNullSafe2 = Expressions.equal(expectedTerm, value); + Expression actualEqNullSafe2 = SparkV2Filters.convert(eqNullSafe2); + assertEquals(expectedEqNullSafe2, actualEqNullSafe2); + + Predicate in = new Predicate("IN", attrAndValue); + Expression expectedIn = Expressions.in(expectedTerm, value); + Expression actualIn = SparkV2Filters.convert(in); + assertEquals(expectedIn, actualIn); + + Predicate notIn = new Not(in); + Expression expectedNotIn = + Expressions.and(Expressions.notNull(expectedTerm), Expressions.notIn(expectedTerm, value)); + Expression actualNotIn = SparkV2Filters.convert(notIn); + assertEquals(expectedNotIn, actualNotIn); + + Predicate and = new And(lt1, eq1); + Expression expectedAnd = Expressions.and(expectedLt1, expectedEq1); + Expression actualAnd = SparkV2Filters.convert(and); + assertEquals(expectedAnd, actualAnd); + + org.apache.spark.sql.connector.expressions.Expression[] attrAndAttr = expressions(udf, udf); + Predicate invalid = new Predicate("<", 
attrAndAttr); + Predicate andWithInvalidLeft = new And(invalid, eq1); + Expression convertedAnd = SparkV2Filters.convert(andWithInvalidLeft); + assertThat(convertedAnd).isNull(); + + Predicate or = new Or(lt1, eq1); + Expression expectedOr = Expressions.or(expectedLt1, expectedEq1); + Expression actualOr = SparkV2Filters.convert(or); + assertEquals(expectedOr, actualOr); + + Predicate orWithInvalidLeft = new Or(invalid, eq1); + Expression convertedOr = SparkV2Filters.convert(orWithInvalidLeft); + assertThat(convertedOr).isNull(); + + Predicate not = new Not(lt1); + Expression expectedNot = Expressions.not(expectedLt1); + Expression actualNot = SparkV2Filters.convert(not); + assertEquals(expectedNot, actualNot); + } + + private static void assertEquals(Expression expected, Expression actual) { + assertThat(ExpressionUtil.equivalent(expected, actual, STRUCT, true)).isTrue(); + } + + private org.apache.spark.sql.connector.expressions.Expression[] expressions( + org.apache.spark.sql.connector.expressions.Expression... expressions) { + return expressions; + } + + private static int dateToYears(String dateString) { + return DateTimeUtil.daysToYears(DateTimeUtil.isoDateToDays(dateString)); + } + + private static int timestampToYears(String timestampString) { + return DateTimeUtil.microsToYears(DateTimeUtil.isoTimestamptzToMicros(timestampString)); + } + + private static int timestampNtzToYears(String timestampNtzString) { + return DateTimeUtil.microsToYears(DateTimeUtil.isoTimestampToMicros(timestampNtzString)); + } + + private static int dateToMonths(String dateString) { + return DateTimeUtil.daysToMonths(DateTimeUtil.isoDateToDays(dateString)); + } + + private static int timestampToMonths(String timestampString) { + return DateTimeUtil.microsToMonths(DateTimeUtil.isoTimestamptzToMicros(timestampString)); + } + + private static int timestampNtzToMonths(String timestampNtzString) { + return DateTimeUtil.microsToMonths(DateTimeUtil.isoTimestampToMicros(timestampNtzString)); + } + + private static int dateToDays(String dateString) { + return DateTimeUtil.isoDateToDays(dateString); + } + + private static int timestampToDays(String timestampString) { + return DateTimeUtil.microsToDays(DateTimeUtil.isoTimestamptzToMicros(timestampString)); + } + + private static int timestampNtzToDays(String timestampNtzString) { + return DateTimeUtil.microsToDays(DateTimeUtil.isoTimestampToMicros(timestampNtzString)); + } + + private static int timestampToHours(String timestampString) { + return DateTimeUtil.microsToHours(DateTimeUtil.isoTimestamptzToMicros(timestampString)); + } + + private static int timestampNtzToHours(String timestampNtzString) { + return DateTimeUtil.microsToHours(DateTimeUtil.isoTimestampToMicros(timestampNtzString)); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkValueConverter.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkValueConverter.java new file mode 100644 index 000000000000..c7a2e6c18fca --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkValueConverter.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.junit.jupiter.api.Test; + +public class TestSparkValueConverter { + @Test + public void testSparkNullMapConvert() { + Schema schema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 5, + "locations", + Types.MapType.ofOptional( + 6, + 7, + Types.StringType.get(), + Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get()))))); + + assertCorrectNullConversion(schema); + } + + @Test + public void testSparkNullListConvert() { + Schema schema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 5, "locations", Types.ListType.ofOptional(6, Types.StringType.get()))); + + assertCorrectNullConversion(schema); + } + + @Test + public void testSparkNullStructConvert() { + Schema schema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 5, + "location", + Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get())))); + + assertCorrectNullConversion(schema); + } + + @Test + public void testSparkNullPrimitiveConvert() { + Schema schema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(5, "location", Types.StringType.get())); + assertCorrectNullConversion(schema); + } + + private void assertCorrectNullConversion(Schema schema) { + Row sparkRow = RowFactory.create(1, null); + Record record = GenericRecord.create(schema); + record.set(0, 1); + assertThat(SparkValueConverter.convert(schema, sparkRow)) + .as("Round-trip conversion should produce original value") + .isEqualTo(record); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java new file mode 100644 index 000000000000..c2df62697882 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java @@ -0,0 +1,583 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.TableProperties.AVRO_COMPRESSION; +import static org.apache.iceberg.TableProperties.AVRO_COMPRESSION_LEVEL; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.DELETE_AVRO_COMPRESSION; +import static org.apache.iceberg.TableProperties.DELETE_AVRO_COMPRESSION_LEVEL; +import static org.apache.iceberg.TableProperties.DELETE_DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.DELETE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.DELETE_ORC_COMPRESSION; +import static org.apache.iceberg.TableProperties.DELETE_ORC_COMPRESSION_STRATEGY; +import static org.apache.iceberg.TableProperties.DELETE_PARQUET_COMPRESSION; +import static org.apache.iceberg.TableProperties.DELETE_PARQUET_COMPRESSION_LEVEL; +import static org.apache.iceberg.TableProperties.MERGE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.ORC_COMPRESSION; +import static org.apache.iceberg.TableProperties.ORC_COMPRESSION_STRATEGY; +import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION; +import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION_LEVEL; +import static org.apache.iceberg.TableProperties.UPDATE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_NONE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; +import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_CODEC; +import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_LEVEL; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.UpdateProperties; +import org.apache.iceberg.deletes.DeleteGranularity; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkWriteConf extends TestBaseWithCatalog { + + @BeforeEach + public void before() { + super.before(); + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + } + + @AfterEach + public void after() { + 
sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testOptionCaseInsensitive() { + Table table = validationCatalog.loadTable(tableIdent); + Map options = ImmutableMap.of("option", "value"); + SparkConfParser parser = new SparkConfParser(spark, table, options); + String parsedValue = parser.stringConf().option("oPtIoN").parseOptional(); + assertThat(parsedValue).isEqualTo("value"); + } + + @TestTemplate + public void testCamelCaseSparkSessionConf() { + Table table = validationCatalog.loadTable(tableIdent); + String confName = "spark.sql.iceberg.some-int-conf"; + String sparkConfName = "spark.sql.iceberg.someIntConf"; + + withSQLConf( + ImmutableMap.of(sparkConfName, "1"), + () -> { + SparkConfParser parser = new SparkConfParser(spark, table, ImmutableMap.of()); + Integer value = parser.intConf().sessionConf(confName).parseOptional(); + assertThat(value).isEqualTo(1); + }); + } + + @TestTemplate + public void testCamelCaseSparkOption() { + Table table = validationCatalog.loadTable(tableIdent); + String option = "some-int-option"; + String sparkOption = "someIntOption"; + Map options = ImmutableMap.of(sparkOption, "1"); + SparkConfParser parser = new SparkConfParser(spark, table, options); + Integer value = parser.intConf().option(option).parseOptional(); + assertThat(value).isEqualTo(1); + } + + @TestTemplate + public void testDurationConf() { + Table table = validationCatalog.loadTable(tableIdent); + String confName = "spark.sql.iceberg.some-duration-conf"; + + withSQLConf( + ImmutableMap.of(confName, "10s"), + () -> { + SparkConfParser parser = new SparkConfParser(spark, table, ImmutableMap.of()); + Duration duration = parser.durationConf().sessionConf(confName).parseOptional(); + assertThat(duration).hasSeconds(10); + }); + + withSQLConf( + ImmutableMap.of(confName, "2m"), + () -> { + SparkConfParser parser = new SparkConfParser(spark, table, ImmutableMap.of()); + Duration duration = parser.durationConf().sessionConf(confName).parseOptional(); + assertThat(duration).hasMinutes(2); + }); + } + + @TestTemplate + public void testDeleteGranularityDefault() { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + DeleteGranularity value = writeConf.deleteGranularity(); + assertThat(value).isEqualTo(DeleteGranularity.PARTITION); + } + + @TestTemplate + public void testDeleteGranularityTableProperty() { + Table table = validationCatalog.loadTable(tableIdent); + + table + .updateProperties() + .set(TableProperties.DELETE_GRANULARITY, DeleteGranularity.FILE.toString()) + .commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + DeleteGranularity value = writeConf.deleteGranularity(); + assertThat(value).isEqualTo(DeleteGranularity.FILE); + } + + @TestTemplate + public void testDeleteGranularityWriteOption() { + Table table = validationCatalog.loadTable(tableIdent); + + table + .updateProperties() + .set(TableProperties.DELETE_GRANULARITY, DeleteGranularity.PARTITION.toString()) + .commit(); + + Map options = + ImmutableMap.of(SparkWriteOptions.DELETE_GRANULARITY, DeleteGranularity.FILE.toString()); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, options); + + DeleteGranularity value = writeConf.deleteGranularity(); + assertThat(value).isEqualTo(DeleteGranularity.FILE); + } + + @TestTemplate + public void testDeleteGranularityInvalidValue() { + Table table = validationCatalog.loadTable(tableIdent); + + 
table.updateProperties().set(TableProperties.DELETE_GRANULARITY, "invalid").commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + assertThatThrownBy(writeConf::deleteGranularity) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Unknown delete granularity"); + } + + @TestTemplate + public void testAdvisoryPartitionSize() { + Table table = validationCatalog.loadTable(tableIdent); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + long value1 = writeConf.writeRequirements().advisoryPartitionSize(); + assertThat(value1).isGreaterThan(64L * 1024 * 1024).isLessThan(2L * 1024 * 1024 * 1024); + + spark.conf().set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES().key(), "2GB"); + long value2 = writeConf.writeRequirements().advisoryPartitionSize(); + assertThat(value2).isEqualTo(2L * 1024 * 1024 * 1024); + + spark.conf().set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES().key(), "10MB"); + long value3 = writeConf.writeRequirements().advisoryPartitionSize(); + assertThat(value3).isGreaterThan(10L * 1024 * 1024); + } + + @TestTemplate + public void testSparkWriteConfDistributionDefault() { + Table table = validationCatalog.loadTable(tableIdent); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + checkMode(DistributionMode.HASH, writeConf); + } + + @TestTemplate + public void testSparkWriteConfDistributionModeWithWriteOption() { + Table table = validationCatalog.loadTable(tableIdent); + + Map writeOptions = + ImmutableMap.of(SparkWriteOptions.DISTRIBUTION_MODE, DistributionMode.NONE.modeName()); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, writeOptions); + checkMode(DistributionMode.NONE, writeConf); + } + + @TestTemplate + public void testSparkWriteConfDistributionModeWithSessionConfig() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.DISTRIBUTION_MODE, DistributionMode.NONE.modeName()), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + checkMode(DistributionMode.NONE, writeConf); + }); + } + + @TestTemplate + public void testSparkWriteConfDistributionModeWithTableProperties() { + Table table = validationCatalog.loadTable(tableIdent); + + table + .updateProperties() + .set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + checkMode(DistributionMode.NONE, writeConf); + } + + @TestTemplate + public void testSparkWriteConfDistributionModeWithTblPropAndSessionConfig() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.DISTRIBUTION_MODE, DistributionMode.NONE.modeName()), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + + table + .updateProperties() + .set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + // session config overwrite the table properties + checkMode(DistributionMode.NONE, writeConf); + }); + } + + @TestTemplate + public void 
testSparkWriteConfDistributionModeWithWriteOptionAndSessionConfig() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.DISTRIBUTION_MODE, DistributionMode.RANGE.modeName()), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + + Map writeOptions = + ImmutableMap.of( + SparkWriteOptions.DISTRIBUTION_MODE, DistributionMode.NONE.modeName()); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, writeOptions); + // write options overwrite the session config + checkMode(DistributionMode.NONE, writeConf); + }); + } + + @TestTemplate + public void testSparkWriteConfDistributionModeWithEverything() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.DISTRIBUTION_MODE, DistributionMode.RANGE.modeName()), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + + Map writeOptions = + ImmutableMap.of( + SparkWriteOptions.DISTRIBUTION_MODE, DistributionMode.NONE.modeName()); + + table + .updateProperties() + .set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, writeOptions); + // write options take the highest priority + checkMode(DistributionMode.NONE, writeConf); + }); + } + + @TestTemplate + public void testSparkConfOverride() { + List>> propertiesSuites = + Lists.newArrayList( + Lists.newArrayList( + ImmutableMap.of(COMPRESSION_CODEC, "zstd", COMPRESSION_LEVEL, "3"), + ImmutableMap.of( + DEFAULT_FILE_FORMAT, + "parquet", + DELETE_DEFAULT_FILE_FORMAT, + "parquet", + TableProperties.PARQUET_COMPRESSION, + "gzip", + TableProperties.DELETE_PARQUET_COMPRESSION, + "snappy"), + ImmutableMap.of( + DELETE_PARQUET_COMPRESSION, + "zstd", + PARQUET_COMPRESSION, + "zstd", + PARQUET_COMPRESSION_LEVEL, + "3", + DELETE_PARQUET_COMPRESSION_LEVEL, + "3")), + Lists.newArrayList( + ImmutableMap.of( + COMPRESSION_CODEC, + "zstd", + SparkSQLProperties.COMPRESSION_STRATEGY, + "compression"), + ImmutableMap.of( + DEFAULT_FILE_FORMAT, + "orc", + DELETE_DEFAULT_FILE_FORMAT, + "orc", + ORC_COMPRESSION, + "zlib", + DELETE_ORC_COMPRESSION, + "snappy"), + ImmutableMap.of( + DELETE_ORC_COMPRESSION, + "zstd", + ORC_COMPRESSION, + "zstd", + DELETE_ORC_COMPRESSION_STRATEGY, + "compression", + ORC_COMPRESSION_STRATEGY, + "compression")), + Lists.newArrayList( + ImmutableMap.of(COMPRESSION_CODEC, "zstd", COMPRESSION_LEVEL, "9"), + ImmutableMap.of( + DEFAULT_FILE_FORMAT, + "avro", + DELETE_DEFAULT_FILE_FORMAT, + "avro", + AVRO_COMPRESSION, + "gzip", + DELETE_AVRO_COMPRESSION, + "snappy"), + ImmutableMap.of( + DELETE_AVRO_COMPRESSION, + "zstd", + AVRO_COMPRESSION, + "zstd", + AVRO_COMPRESSION_LEVEL, + "9", + DELETE_AVRO_COMPRESSION_LEVEL, + "9"))); + for (List> propertiesSuite : propertiesSuites) { + testWriteProperties(propertiesSuite); + } + } + + @TestTemplate + public void testDataPropsDefaultsAsDeleteProps() { + List>> propertiesSuites = + Lists.newArrayList( + Lists.newArrayList( + ImmutableMap.of(), + ImmutableMap.of( + DEFAULT_FILE_FORMAT, + "parquet", + DELETE_DEFAULT_FILE_FORMAT, + "parquet", + PARQUET_COMPRESSION, + "zstd", + PARQUET_COMPRESSION_LEVEL, + "5"), + ImmutableMap.of( + DELETE_PARQUET_COMPRESSION, + "zstd", + PARQUET_COMPRESSION, + "zstd", + PARQUET_COMPRESSION_LEVEL, + "5", + DELETE_PARQUET_COMPRESSION_LEVEL, + "5")), + Lists.newArrayList( + ImmutableMap.of(), + ImmutableMap.of( + DEFAULT_FILE_FORMAT, + 
"orc", + DELETE_DEFAULT_FILE_FORMAT, + "orc", + ORC_COMPRESSION, + "snappy", + ORC_COMPRESSION_STRATEGY, + "speed"), + ImmutableMap.of( + DELETE_ORC_COMPRESSION, + "snappy", + ORC_COMPRESSION, + "snappy", + ORC_COMPRESSION_STRATEGY, + "speed", + DELETE_ORC_COMPRESSION_STRATEGY, + "speed")), + Lists.newArrayList( + ImmutableMap.of(), + ImmutableMap.of( + DEFAULT_FILE_FORMAT, + "avro", + DELETE_DEFAULT_FILE_FORMAT, + "avro", + AVRO_COMPRESSION, + "snappy", + AVRO_COMPRESSION_LEVEL, + "9"), + ImmutableMap.of( + DELETE_AVRO_COMPRESSION, + "snappy", + AVRO_COMPRESSION, + "snappy", + AVRO_COMPRESSION_LEVEL, + "9", + DELETE_AVRO_COMPRESSION_LEVEL, + "9"))); + for (List> propertiesSuite : propertiesSuites) { + testWriteProperties(propertiesSuite); + } + } + + @TestTemplate + public void testDeleteFileWriteConf() { + List>> propertiesSuites = + Lists.newArrayList( + Lists.newArrayList( + ImmutableMap.of(), + ImmutableMap.of( + DEFAULT_FILE_FORMAT, + "parquet", + DELETE_DEFAULT_FILE_FORMAT, + "parquet", + TableProperties.PARQUET_COMPRESSION, + "zstd", + PARQUET_COMPRESSION_LEVEL, + "5", + DELETE_PARQUET_COMPRESSION_LEVEL, + "6"), + ImmutableMap.of( + DELETE_PARQUET_COMPRESSION, + "zstd", + PARQUET_COMPRESSION, + "zstd", + PARQUET_COMPRESSION_LEVEL, + "5", + DELETE_PARQUET_COMPRESSION_LEVEL, + "6")), + Lists.newArrayList( + ImmutableMap.of(), + ImmutableMap.of( + DEFAULT_FILE_FORMAT, + "orc", + DELETE_DEFAULT_FILE_FORMAT, + "orc", + ORC_COMPRESSION, + "snappy", + ORC_COMPRESSION_STRATEGY, + "speed", + DELETE_ORC_COMPRESSION, + "zstd", + DELETE_ORC_COMPRESSION_STRATEGY, + "compression"), + ImmutableMap.of( + DELETE_ORC_COMPRESSION, + "zstd", + ORC_COMPRESSION, + "snappy", + ORC_COMPRESSION_STRATEGY, + "speed", + DELETE_ORC_COMPRESSION_STRATEGY, + "compression")), + Lists.newArrayList( + ImmutableMap.of(), + ImmutableMap.of( + DEFAULT_FILE_FORMAT, + "avro", + DELETE_DEFAULT_FILE_FORMAT, + "avro", + AVRO_COMPRESSION, + "snappy", + AVRO_COMPRESSION_LEVEL, + "9", + DELETE_AVRO_COMPRESSION, + "zstd", + DELETE_AVRO_COMPRESSION_LEVEL, + "16"), + ImmutableMap.of( + DELETE_AVRO_COMPRESSION, + "zstd", + AVRO_COMPRESSION, + "snappy", + AVRO_COMPRESSION_LEVEL, + "9", + DELETE_AVRO_COMPRESSION_LEVEL, + "16"))); + for (List> propertiesSuite : propertiesSuites) { + testWriteProperties(propertiesSuite); + } + } + + private void testWriteProperties(List> propertiesSuite) { + withSQLConf( + propertiesSuite.get(0), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + Map tableProperties = propertiesSuite.get(1); + UpdateProperties updateProperties = table.updateProperties(); + for (Map.Entry entry : tableProperties.entrySet()) { + updateProperties.set(entry.getKey(), entry.getValue()); + } + + updateProperties.commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + Map writeProperties = writeConf.writeProperties(); + Map expectedProperties = propertiesSuite.get(2); + assertThat(writeConf.writeProperties()).hasSameSizeAs(expectedProperties); + for (Map.Entry entry : writeProperties.entrySet()) { + assertThat(expectedProperties).containsEntry(entry.getKey(), entry.getValue()); + } + + table.refresh(); + updateProperties = table.updateProperties(); + for (Map.Entry entry : tableProperties.entrySet()) { + updateProperties.remove(entry.getKey()); + } + + updateProperties.commit(); + }); + } + + private void checkMode(DistributionMode expectedMode, SparkWriteConf writeConf) { + assertThat(writeConf.distributionMode()).isEqualTo(expectedMode); + 
assertThat(writeConf.copyOnWriteDistributionMode(DELETE)).isEqualTo(expectedMode); + assertThat(writeConf.positionDeltaDistributionMode(DELETE)).isEqualTo(expectedMode); + assertThat(writeConf.copyOnWriteDistributionMode(UPDATE)).isEqualTo(expectedMode); + assertThat(writeConf.positionDeltaDistributionMode(UPDATE)).isEqualTo(expectedMode); + assertThat(writeConf.copyOnWriteDistributionMode(MERGE)).isEqualTo(expectedMode); + assertThat(writeConf.positionDeltaDistributionMode(MERGE)).isEqualTo(expectedMode); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestComputeTableStatsAction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestComputeTableStatsAction.java new file mode 100644 index 000000000000..7aa849d0bba8 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestComputeTableStatsAction.java @@ -0,0 +1,406 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.spark.actions.NDVSketchUtil.APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.IOException; +import java.util.List; +import org.apache.iceberg.BlobMetadata; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Files; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.ComputeTableStats; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.iceberg.types.Types; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; 
+import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestComputeTableStatsAction extends CatalogTestBase { + + private static final Types.StructType LEAF_STRUCT_TYPE = + Types.StructType.of( + optional(1, "leafLongCol", Types.LongType.get()), + optional(2, "leafDoubleCol", Types.DoubleType.get())); + + private static final Types.StructType NESTED_STRUCT_TYPE = + Types.StructType.of(required(3, "leafStructCol", LEAF_STRUCT_TYPE)); + + private static final Schema NESTED_SCHEMA = + new Schema(required(4, "nestedStructCol", NESTED_STRUCT_TYPE)); + + private static final Schema SCHEMA_WITH_NESTED_COLUMN = + new Schema( + required(4, "nestedStructCol", NESTED_STRUCT_TYPE), + required(5, "stringCol", Types.StringType.get())); + + @TestTemplate + public void testLoadingTableDirectly() { + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + sql("INSERT into %s values(1, 'abcd')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + SparkActions actions = SparkActions.get(); + ComputeTableStats.Result results = actions.computeTableStats(table).execute(); + StatisticsFile statisticsFile = results.statisticsFile(); + assertThat(statisticsFile.fileSizeInBytes()).isGreaterThan(0); + assertThat(statisticsFile.blobMetadata()).hasSize(2); + } + + @TestTemplate + public void testComputeTableStatsAction() throws NoSuchTableException, ParseException { + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + // To create multiple splits on the mapper + table + .updateProperties() + .set("read.split.target-size", "100") + .set("write.parquet.row-group-size-bytes", "100") + .commit(); + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d")); + spark.createDataset(records, Encoders.bean(SimpleRecord.class)).writeTo(tableName).append(); + SparkActions actions = SparkActions.get(); + ComputeTableStats.Result results = + actions.computeTableStats(table).columns("id", "data").execute(); + assertThat(results).isNotNull(); + + List statisticsFiles = table.statisticsFiles(); + assertThat(statisticsFiles).hasSize(1); + + StatisticsFile statisticsFile = statisticsFiles.get(0); + assertThat(statisticsFile.fileSizeInBytes()).isGreaterThan(0); + assertThat(statisticsFile.blobMetadata()).hasSize(2); + + BlobMetadata blobMetadata = statisticsFile.blobMetadata().get(0); + assertThat(blobMetadata.properties()) + .containsEntry(APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY, "4"); + } + + @TestTemplate + public void testComputeTableStatsActionWithoutExplicitColumns() + throws NoSuchTableException, ParseException { + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + Table table = Spark3Util.loadIcebergTable(spark, tableName); + SparkActions actions = SparkActions.get(); + ComputeTableStats.Result results = actions.computeTableStats(table).execute(); + assertThat(results).isNotNull(); + + assertThat(table.statisticsFiles()).hasSize(1); + StatisticsFile statisticsFile = table.statisticsFiles().get(0); + 
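The statistics tests above all drive ComputeTableStatsAction through the public action API. A condensed sketch of that call path, with a hypothetical helper name and wrapper class:

```java
import org.apache.iceberg.StatisticsFile;
import org.apache.iceberg.Table;
import org.apache.iceberg.actions.ComputeTableStats;
import org.apache.iceberg.spark.actions.SparkActions;

class ComputeStatsSketch {
  // Computes NDV sketch statistics for the given columns and returns the statistics
  // file registered on the table's current snapshot (null when the table has no
  // snapshot yet, as asserted in testComputeTableStatsWithNoSnapshots above).
  static StatisticsFile computeNdv(Table table, String... columns) {
    ComputeTableStats.Result result =
        SparkActions.get().computeTableStats(table).columns(columns).execute();
    return result.statisticsFile();
  }
}
```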
assertThat(statisticsFile.fileSizeInBytes()).isGreaterThan(0); + assertThat(statisticsFile.blobMetadata()).hasSize(2); + assertThat(statisticsFile.blobMetadata().get(0).properties()) + .containsEntry(APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY, "4"); + assertThat(statisticsFile.blobMetadata().get(1).properties()) + .containsEntry(APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY, "4"); + } + + @TestTemplate + public void testComputeTableStatsForInvalidColumns() throws NoSuchTableException, ParseException { + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + // Append data to create snapshot + sql("INSERT into %s values(1, 'abcd')", tableName); + Table table = Spark3Util.loadIcebergTable(spark, tableName); + SparkActions actions = SparkActions.get(); + assertThatThrownBy(() -> actions.computeTableStats(table).columns("id1").execute()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Can't find column id1 in table"); + } + + @TestTemplate + public void testComputeTableStatsWithNoSnapshots() throws NoSuchTableException, ParseException { + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + Table table = Spark3Util.loadIcebergTable(spark, tableName); + SparkActions actions = SparkActions.get(); + ComputeTableStats.Result result = actions.computeTableStats(table).columns("id").execute(); + assertThat(result.statisticsFile()).isNull(); + } + + @TestTemplate + public void testComputeTableStatsWithNullValues() throws NoSuchTableException, ParseException { + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + List records = + Lists.newArrayList( + new SimpleRecord(1, null), + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + Table table = Spark3Util.loadIcebergTable(spark, tableName); + SparkActions actions = SparkActions.get(); + ComputeTableStats.Result results = actions.computeTableStats(table).columns("data").execute(); + assertThat(results).isNotNull(); + + List statisticsFiles = table.statisticsFiles(); + assertThat(statisticsFiles).hasSize(1); + + StatisticsFile statisticsFile = statisticsFiles.get(0); + assertThat(statisticsFile.fileSizeInBytes()).isGreaterThan(0); + assertThat(statisticsFile.blobMetadata()).hasSize(1); + + assertThat(statisticsFile.blobMetadata().get(0).properties()) + .containsEntry(APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY, "4"); + } + + @TestTemplate + public void testComputeTableStatsWithSnapshotHavingDifferentSchemas() + throws NoSuchTableException, ParseException { + SparkActions actions = SparkActions.get(); + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + // Append data to create snapshot + sql("INSERT into %s values(1, 'abcd')", tableName); + long snapshotId1 = Spark3Util.loadIcebergTable(spark, tableName).currentSnapshot().snapshotId(); + // Snapshot id not specified + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + assertThatNoException() + .isThrownBy(() -> actions.computeTableStats(table).columns("data").execute()); + + sql("ALTER TABLE %s DROP COLUMN %s", tableName, "data"); + // Append data to create snapshot + sql("INSERT into %s values(1)", tableName); + table.refresh(); + long snapshotId2 = Spark3Util.loadIcebergTable(spark, tableName).currentSnapshot().snapshotId(); + + // Snapshot id specified + assertThatNoException() + .isThrownBy( + () -> 
actions.computeTableStats(table).snapshot(snapshotId1).columns("data").execute()); + + assertThatThrownBy( + () -> actions.computeTableStats(table).snapshot(snapshotId2).columns("data").execute()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Can't find column data in table"); + } + + @TestTemplate + public void testComputeTableStatsWhenSnapshotIdNotSpecified() + throws NoSuchTableException, ParseException { + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + // Append data to create snapshot + sql("INSERT into %s values(1, 'abcd')", tableName); + Table table = Spark3Util.loadIcebergTable(spark, tableName); + SparkActions actions = SparkActions.get(); + ComputeTableStats.Result results = actions.computeTableStats(table).columns("data").execute(); + + assertThat(results).isNotNull(); + + List statisticsFiles = table.statisticsFiles(); + assertThat(statisticsFiles).hasSize(1); + + StatisticsFile statisticsFile = statisticsFiles.get(0); + assertThat(statisticsFile.fileSizeInBytes()).isGreaterThan(0); + assertThat(statisticsFile.blobMetadata()).hasSize(1); + + assertThat(statisticsFile.blobMetadata().get(0).properties()) + .containsEntry(APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY, "1"); + } + + @TestTemplate + public void testComputeTableStatsWithNestedSchema() + throws NoSuchTableException, ParseException, IOException { + List records = Lists.newArrayList(createNestedRecord()); + Table table = + validationCatalog.createTable( + tableIdent, + SCHEMA_WITH_NESTED_COLUMN, + PartitionSpec.unpartitioned(), + ImmutableMap.of()); + DataFile dataFile = FileHelpers.writeDataFile(table, Files.localOutput(temp.toFile()), records); + table.newAppend().appendFile(dataFile).commit(); + + Table tbl = Spark3Util.loadIcebergTable(spark, tableName); + SparkActions actions = SparkActions.get(); + actions.computeTableStats(tbl).execute(); + + tbl.refresh(); + List statisticsFiles = tbl.statisticsFiles(); + assertThat(statisticsFiles).hasSize(1); + StatisticsFile statisticsFile = statisticsFiles.get(0); + assertThat(statisticsFile.fileSizeInBytes()).isGreaterThan(0); + assertThat(statisticsFile.blobMetadata()).hasSize(1); + } + + @TestTemplate + public void testComputeTableStatsWithNoComputableColumns() throws IOException { + List records = Lists.newArrayList(createNestedRecord()); + Table table = + validationCatalog.createTable( + tableIdent, NESTED_SCHEMA, PartitionSpec.unpartitioned(), ImmutableMap.of()); + DataFile dataFile = FileHelpers.writeDataFile(table, Files.localOutput(temp.toFile()), records); + table.newAppend().appendFile(dataFile).commit(); + + table.refresh(); + SparkActions actions = SparkActions.get(); + assertThatThrownBy(() -> actions.computeTableStats(table).execute()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("No columns found to compute stats"); + } + + @TestTemplate + public void testComputeTableStatsOnByteColumn() throws NoSuchTableException, ParseException { + testComputeTableStats("byte_col", "TINYINT"); + } + + @TestTemplate + public void testComputeTableStatsOnShortColumn() throws NoSuchTableException, ParseException { + testComputeTableStats("short_col", "SMALLINT"); + } + + @TestTemplate + public void testComputeTableStatsOnIntColumn() throws NoSuchTableException, ParseException { + testComputeTableStats("int_col", "INT"); + } + + @TestTemplate + public void testComputeTableStatsOnLongColumn() throws NoSuchTableException, ParseException { + testComputeTableStats("long_col", "BIGINT"); + } + + @TestTemplate + 
public void testComputeTableStatsOnTimestampColumn() throws NoSuchTableException, ParseException { + testComputeTableStats("timestamp_col", "TIMESTAMP"); + } + + @TestTemplate + public void testComputeTableStatsOnTimestampNtzColumn() + throws NoSuchTableException, ParseException { + testComputeTableStats("timestamp_col", "TIMESTAMP_NTZ"); + } + + @TestTemplate + public void testComputeTableStatsOnDateColumn() throws NoSuchTableException, ParseException { + testComputeTableStats("date_col", "DATE"); + } + + @TestTemplate + public void testComputeTableStatsOnDecimalColumn() throws NoSuchTableException, ParseException { + testComputeTableStats("decimal_col", "DECIMAL(20, 2)"); + } + + @TestTemplate + public void testComputeTableStatsOnBinaryColumn() throws NoSuchTableException, ParseException { + testComputeTableStats("binary_col", "BINARY"); + } + + public void testComputeTableStats(String columnName, String type) + throws NoSuchTableException, ParseException { + sql("CREATE TABLE %s (id int, %s %s) USING iceberg", tableName, columnName, type); + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + Dataset dataDF = randomDataDF(table.schema()); + append(tableName, dataDF); + + SparkActions actions = SparkActions.get(); + table.refresh(); + ComputeTableStats.Result results = + actions.computeTableStats(table).columns(columnName).execute(); + assertThat(results).isNotNull(); + + List statisticsFiles = table.statisticsFiles(); + assertThat(statisticsFiles).hasSize(1); + + StatisticsFile statisticsFile = statisticsFiles.get(0); + assertThat(statisticsFile.fileSizeInBytes()).isGreaterThan(0); + assertThat(statisticsFile.blobMetadata()).hasSize(1); + + assertThat(statisticsFile.blobMetadata().get(0).properties()) + .containsKey(APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY); + } + + private GenericRecord createNestedRecord() { + GenericRecord record = GenericRecord.create(SCHEMA_WITH_NESTED_COLUMN); + GenericRecord nested = GenericRecord.create(NESTED_STRUCT_TYPE); + GenericRecord leaf = GenericRecord.create(LEAF_STRUCT_TYPE); + leaf.set(0, 0L); + leaf.set(1, 0.0); + nested.set(0, leaf); + record.set(0, nested); + record.set(1, "data"); + return record; + } + + private Dataset randomDataDF(Schema schema) { + Iterable rows = RandomData.generateSpark(schema, 10, 0); + JavaRDD rowRDD = sparkContext.parallelize(Lists.newArrayList(rows)); + StructType rowSparkType = SparkSchemaUtil.convert(schema); + return spark.internalCreateDataFrame(JavaRDD.toRDD(rowRDD), rowSparkType, false); + } + + private void append(String table, Dataset df) throws NoSuchTableException { + // fanout writes are enabled as write-time clustering is not supported without Spark extensions + df.coalesce(1).writeTo(table).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + } + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestCreateActions.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestCreateActions.java new file mode 100644 index 000000000000..6954903b4102 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestCreateActions.java @@ -0,0 +1,1076 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
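One detail worth calling out from the schema-evolution case above: stats can be computed against an explicit snapshot id, and the requested columns must exist in that snapshot's schema. A hypothetical helper illustrating that call:

```java
import org.apache.iceberg.Table;
import org.apache.iceberg.actions.ComputeTableStats;
import org.apache.iceberg.spark.actions.SparkActions;

class SnapshotScopedStatsSketch {
  // Computes stats pinned to a specific snapshot; fails with IllegalArgumentException
  // when `column` is not part of that snapshot's schema, as verified above.
  static ComputeTableStats.Result statsForSnapshot(Table table, long snapshotId, String column) {
    return SparkActions.get()
        .computeTableStats(table)
        .snapshot(snapshotId)
        .columns(column)
        .execute();
  }
}
```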
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.File; +import java.io.FilenameFilter; +import java.io.IOException; +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.commons.io.FileUtils; +import org.apache.commons.io.filefilter.TrueFileFilter; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.MigrateTable; +import org.apache.iceberg.actions.SnapshotTable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.types.Types; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.util.HadoopInputFile; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.MessageTypeParser; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.catalog.CatalogTable; +import org.apache.spark.sql.catalyst.catalog.CatalogTablePartition; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.io.TempDir; +import scala.Option; +import scala.Some; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +public class TestCreateActions extends 
CatalogTestBase { + private static final String CREATE_PARTITIONED_PARQUET = + "CREATE TABLE %s (id INT, data STRING) " + "using parquet PARTITIONED BY (id) LOCATION '%s'"; + private static final String CREATE_PARQUET = + "CREATE TABLE %s (id INT, data STRING) " + "using parquet LOCATION '%s'"; + private static final String CREATE_HIVE_EXTERNAL_PARQUET = + "CREATE EXTERNAL TABLE %s (data STRING) " + + "PARTITIONED BY (id INT) STORED AS parquet LOCATION '%s'"; + private static final String CREATE_HIVE_PARQUET = + "CREATE TABLE %s (data STRING) " + "PARTITIONED BY (id INT) STORED AS parquet"; + + private static final String NAMESPACE = "default"; + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}, type = {3}") + public static Object[][] parameters() { + return new Object[][] { + new Object[] { + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "parquet-enabled", "true", + "cache-enabled", + "false" // Spark will delete tables using v1, leaving the cache out of sync + ), + "hive" + }, + new Object[] { + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hadoop", + "default-namespace", "default", + "parquet-enabled", "true", + "cache-enabled", + "false" // Spark will delete tables using v1, leaving the cache out of sync + ), + "hadoop" + }, + new Object[] { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default"), + "hive" + }, + new Object[] { + "testhadoop", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hadoop", + "default-namespace", "default"), + "hadoop" + } + }; + } + + private final String baseTableName = "baseTable"; + @TempDir private File tableDir; + private String tableLocation; + + @Parameter(index = 3) + private String type; + + private TableCatalog catalog; + + @BeforeEach + @Override + public void before() { + super.before(); + this.tableLocation = tableDir.toURI().toString(); + this.catalog = (TableCatalog) spark.sessionState().catalogManager().catalog(catalogName); + + spark.conf().set("hive.exec.dynamic.partition", "true"); + spark.conf().set("hive.exec.dynamic.partition.mode", "nonstrict"); + spark.conf().set("spark.sql.parquet.writeLegacyFormat", false); + spark.conf().set("spark.sql.parquet.writeLegacyFormat", false); + spark.sql(String.format("DROP TABLE IF EXISTS %s", baseTableName)); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id", "data") + .orderBy("data") + .write() + .mode("append") + .option("path", tableLocation) + .saveAsTable(baseTableName); + } + + @AfterEach + public void after() throws IOException { + // Drop the hive table. 
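The parameter matrix above covers both SparkSessionCatalog and SparkCatalog backed by Hive or Hadoop. As a rough sketch, the session-catalog rows correspond to configuration along these lines (builder usage is illustrative; the keys match the ones after() unsets):

```java
import org.apache.spark.sql.SparkSession;

class CatalogConfigSketch {
  // Registers an Iceberg-aware replacement for the built-in session catalog,
  // mirroring the "spark_catalog"/hive row of the parameter matrix above.
  static SparkSession withIcebergSessionCatalog(SparkSession.Builder builder) {
    return builder
        .config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog")
        .config("spark.sql.catalog.spark_catalog.type", "hive")
        .config("spark.sql.catalog.spark_catalog.default-namespace", "default")
        .config("spark.sql.catalog.spark_catalog.parquet-enabled", "true")
        .config("spark.sql.catalog.spark_catalog.cache-enabled", "false")
        .getOrCreate();
  }
}
```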
+ spark.sql(String.format("DROP TABLE IF EXISTS %s", baseTableName)); + spark.sessionState().catalogManager().reset(); + spark.conf().unset("spark.sql.catalog.spark_catalog.type"); + spark.conf().unset("spark.sql.catalog.spark_catalog.default-namespace"); + spark.conf().unset("spark.sql.catalog.spark_catalog.parquet-enabled"); + spark.conf().unset("spark.sql.catalog.spark_catalog.cache-enabled"); + } + + @TestTemplate + public void testMigratePartitioned() throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + assumeThat(catalog.name()) + .as("Can only migrate from Spark Session Catalog") + .isEqualTo("spark_catalog"); + String source = sourceName("test_migrate_partitioned_table"); + String dest = source; + createSourceTable(CREATE_PARTITIONED_PARQUET, source); + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @TestTemplate + public void testPartitionedTableWithUnRecoveredPartitions() throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + assumeThat(catalog.name()) + .as("Can only migrate from Spark Session Catalog") + .isEqualTo("spark_catalog"); + String source = sourceName("test_unrecovered_partitions"); + String dest = source; + File location = Files.createTempDirectory(temp, "junit").toFile(); + sql(CREATE_PARTITIONED_PARQUET, source, location); + + // Data generation and partition addition + spark + .range(5) + .selectExpr("id", "cast(id as STRING) as data") + .write() + .partitionBy("id") + .mode(SaveMode.Overwrite) + .parquet(location.toURI().toString()); + sql("ALTER TABLE %s ADD PARTITION(id=0)", source); + + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @TestTemplate + public void testPartitionedTableWithCustomPartitions() throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + assumeThat(catalog.name()) + .as("Can only migrate from Spark Session Catalog") + .isEqualTo("spark_catalog"); + String source = sourceName("test_custom_parts"); + String dest = source; + File tblLocation = Files.createTempDirectory(temp, "junit").toFile(); + File partitionDataLoc = Files.createTempDirectory(temp, "junit").toFile(); + + // Data generation and partition addition + spark.sql(String.format(CREATE_PARTITIONED_PARQUET, source, tblLocation)); + spark + .range(10) + .selectExpr("cast(id as STRING) as data") + .write() + .mode(SaveMode.Overwrite) + .parquet(partitionDataLoc.toURI().toString()); + sql( + "ALTER TABLE %s ADD PARTITION(id=0) LOCATION '%s'", + source, partitionDataLoc.toURI().toString()); + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @TestTemplate + public void testAddColumnOnMigratedTableAtEnd() throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + assumeThat(catalog.name()) + .as("Can only migrate from Spark Session Catalog") + .isEqualTo("spark_catalog"); + String source = sourceName("test_add_column_migrated_table"); + String dest = source; + createSourceTable(CREATE_PARQUET, source); + List expected1 = sql("select *, null from %s order by id", source); + List expected2 = sql("select *, null, null from %s order by id", source); + + // migrate table + SparkActions.get().migrateTable(source).execute(); + SparkTable sparkTable = loadTable(dest); + Table table = sparkTable.table(); + + // test column addition on migrated table + Schema 
beforeSchema = table.schema(); + String newCol1 = "newCol1"; + sparkTable.table().updateSchema().addColumn(newCol1, Types.IntegerType.get()).commit(); + Schema afterSchema = table.schema(); + assertThat(beforeSchema.findField(newCol1)).isNull(); + assertThat(afterSchema.findField(newCol1)).isNotNull(); + + // reads should succeed without any exceptions + List results1 = sql("select * from %s order by id", dest); + assertThat(results1).isNotEmpty(); + assertEquals("Output must match", results1, expected1); + + String newCol2 = "newCol2"; + sql("ALTER TABLE %s ADD COLUMN %s INT", dest, newCol2); + StructType schema = spark.table(dest).schema(); + assertThat(schema.fieldNames()).contains(newCol2); + + // reads should succeed without any exceptions + List results2 = sql("select * from %s order by id", dest); + assertThat(results2).isNotEmpty(); + assertEquals("Output must match", results2, expected2); + } + + @TestTemplate + public void testAddColumnOnMigratedTableAtMiddle() throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + assumeThat(catalog.name()) + .as("Can only migrate from Spark Session Catalog") + .isEqualTo("spark_catalog"); + String source = sourceName("test_add_column_migrated_table_middle"); + String dest = source; + createSourceTable(CREATE_PARQUET, source); + + // migrate table + SparkActions.get().migrateTable(source).execute(); + SparkTable sparkTable = loadTable(dest); + Table table = sparkTable.table(); + List expected = sql("select id, null, data from %s order by id", source); + + // test column addition on migrated table + Schema beforeSchema = table.schema(); + String newCol1 = "newCol"; + sparkTable + .table() + .updateSchema() + .addColumn("newCol", Types.IntegerType.get()) + .moveAfter(newCol1, "id") + .commit(); + Schema afterSchema = table.schema(); + assertThat(beforeSchema.findField(newCol1)).isNull(); + assertThat(afterSchema.findField(newCol1)).isNotNull(); + + // reads should succeed + List results = sql("select * from %s order by id", dest); + assertThat(results).isNotEmpty(); + assertEquals("Output must match", results, expected); + } + + @TestTemplate + public void removeColumnsAtEnd() throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + assumeThat(catalog.name()) + .as("Can only migrate from Spark Session Catalog") + .isEqualTo("spark_catalog"); + String source = sourceName("test_remove_column_migrated_table"); + String dest = source; + + String colName1 = "newCol1"; + String colName2 = "newCol2"; + File location = Files.createTempDirectory(temp, "junit").toFile(); + spark + .range(10) + .selectExpr("cast(id as INT)", "CAST(id as INT) " + colName1, "CAST(id as INT) " + colName2) + .write() + .mode(SaveMode.Overwrite) + .saveAsTable(dest); + List expected1 = sql("select id, %s from %s order by id", colName1, source); + List expected2 = sql("select id from %s order by id", source); + + // migrate table + SparkActions.get().migrateTable(source).execute(); + SparkTable sparkTable = loadTable(dest); + Table table = sparkTable.table(); + + // test column removal on migrated table + Schema beforeSchema = table.schema(); + sparkTable.table().updateSchema().deleteColumn(colName1).commit(); + Schema afterSchema = table.schema(); + assertThat(beforeSchema.findField(colName1)).isNotNull(); + assertThat(afterSchema.findField(colName1)).isNull(); + + // reads should succeed without any exceptions + List results1 = sql("select * from %s order by id", dest); 
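The migrate tests follow a common pattern: run MigrateTable against a session-catalog parquet table in place, then evolve the schema through the Iceberg API and re-read through Spark SQL. A condensed, hypothetical sketch of that flow (table loading is left to the caller because it depends on the catalog configuration):

```java
import org.apache.iceberg.Table;
import org.apache.iceberg.actions.MigrateTable;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.iceberg.types.Types;

class MigrateSketch {
  // Migrates a session-catalog parquet table in place, then adds a column through
  // the Iceberg API; returns the number of data files imported by the migration.
  static long migrateAndAddColumn(String sourceTableName, Table migratedTable) {
    MigrateTable.Result result = SparkActions.get().migrateTable(sourceTableName).execute();
    migratedTable.updateSchema().addColumn("newCol", Types.IntegerType.get()).commit();
    return result.migratedDataFilesCount();
  }
}
```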
+ assertThat(results1).isNotEmpty(); + assertEquals("Output must match", expected1, results1); + + sql("ALTER TABLE %s DROP COLUMN %s", dest, colName2); + StructType schema = spark.table(dest).schema(); + assertThat(schema.fieldNames()).doesNotContain(colName2); + + // reads should succeed without any exceptions + List results2 = sql("select * from %s order by id", dest); + assertThat(results2).isNotEmpty(); + assertEquals("Output must match", expected2, results2); + } + + @TestTemplate + public void removeColumnFromMiddle() throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + assumeThat(catalog.name()) + .as("Can only migrate from Spark Session Catalog") + .isEqualTo("spark_catalog"); + String source = sourceName("test_remove_column_migrated_table_from_middle"); + String dest = source; + String dropColumnName = "col1"; + + spark + .range(10) + .selectExpr( + "cast(id as INT)", "CAST(id as INT) as " + dropColumnName, "CAST(id as INT) as col2") + .write() + .mode(SaveMode.Overwrite) + .saveAsTable(dest); + List expected = sql("select id, col2 from %s order by id", source); + + // migrate table + SparkActions.get().migrateTable(source).execute(); + + // drop column + sql("ALTER TABLE %s DROP COLUMN %s", dest, "col1"); + StructType schema = spark.table(dest).schema(); + assertThat(schema.fieldNames()).doesNotContain(dropColumnName); + + // reads should return same output as that of non-iceberg table + List results = sql("select * from %s order by id", dest); + assertThat(results).isNotEmpty(); + assertEquals("Output must match", expected, results); + } + + @TestTemplate + public void testMigrateUnpartitioned() throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + assumeThat(catalog.name()) + .as("Can only migrate from Spark Session Catalog") + .isEqualTo("spark_catalog"); + String source = sourceName("test_migrate_unpartitioned_table"); + String dest = source; + createSourceTable(CREATE_PARQUET, source); + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @TestTemplate + public void testSnapshotPartitioned() throws Exception { + assumeThat(type) + .as("Cannot snapshot with arbitrary location in a hadoop based catalog") + .isNotEqualTo("hadoop"); + File location = Files.createTempDirectory(temp, "junit").toFile(); + String source = sourceName("test_snapshot_partitioned_table"); + String dest = destName("iceberg_snapshot_partitioned"); + createSourceTable(CREATE_PARTITIONED_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), + source, + dest); + assertIsolatedSnapshot(source, dest); + } + + @TestTemplate + public void testSnapshotUnpartitioned() throws Exception { + assumeThat(type) + .as("Cannot snapshot with arbitrary location in a hadoop based catalog") + .isNotEqualTo("hadoop"); + File location = Files.createTempDirectory(temp, "junit").toFile(); + String source = sourceName("test_snapshot_unpartitioned_table"); + String dest = destName("iceberg_snapshot_unpartitioned"); + createSourceTable(CREATE_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), + source, + dest); + assertIsolatedSnapshot(source, dest); + } + + @TestTemplate + public void testSnapshotHiveTable() throws Exception { + assumeThat(type) + .as("Cannot snapshot with arbitrary location in a hadoop based catalog") + 
.isNotEqualTo("hadoop"); + File location = Files.createTempDirectory(temp, "junit").toFile(); + String source = sourceName("snapshot_hive_table"); + String dest = destName("iceberg_snapshot_hive_table"); + createSourceTable(CREATE_HIVE_EXTERNAL_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), + source, + dest); + assertIsolatedSnapshot(source, dest); + } + + @TestTemplate + public void testMigrateHiveTable() throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + String source = sourceName("migrate_hive_table"); + String dest = source; + createSourceTable(CREATE_HIVE_EXTERNAL_PARQUET, source); + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @TestTemplate + public void testSnapshotManagedHiveTable() throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + File location = Files.createTempDirectory(temp, "junit").toFile(); + String source = sourceName("snapshot_managed_hive_table"); + String dest = destName("iceberg_snapshot_managed_hive_table"); + createSourceTable(CREATE_HIVE_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), + source, + dest); + assertIsolatedSnapshot(source, dest); + } + + @TestTemplate + public void testMigrateManagedHiveTable() throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + File location = Files.createTempDirectory(temp, "junit").toFile(); + String source = sourceName("migrate_managed_hive_table"); + String dest = destName("iceberg_migrate_managed_hive_table"); + createSourceTable(CREATE_HIVE_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), + source, + dest); + } + + @TestTemplate + public void testProperties() throws Exception { + String source = sourceName("test_properties_table"); + String dest = destName("iceberg_properties"); + Map props = Maps.newHashMap(); + props.put("city", "New Orleans"); + props.put("note", "Jazz"); + createSourceTable(CREATE_PARQUET, source); + for (Map.Entry keyValue : props.entrySet()) { + spark.sql( + String.format( + "ALTER TABLE %s SET TBLPROPERTIES (\"%s\" = \"%s\")", + source, keyValue.getKey(), keyValue.getValue())); + } + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableProperty("dogs", "sundance"), + source, + dest); + SparkTable table = loadTable(dest); + + Map expectedProps = Maps.newHashMap(); + expectedProps.putAll(props); + expectedProps.put("dogs", "sundance"); + + for (Map.Entry entry : expectedProps.entrySet()) { + assertThat(table.properties()) + .as("Property value is not the expected value") + .containsEntry(entry.getKey(), entry.getValue()); + } + } + + @TestTemplate + public void testSparkTableReservedProperties() throws Exception { + String destTableName = "iceberg_reserved_properties"; + String source = sourceName("test_reserved_properties_table"); + String dest = destName(destTableName); + createSourceTable(CREATE_PARQUET, source); + SnapshotTableSparkAction action = SparkActions.get().snapshotTable(source).as(dest); + action.tableProperty(TableProperties.FORMAT_VERSION, "1"); + assertSnapshotFileCount(action, source, dest); + SparkTable table = loadTable(dest); + // set sort orders + 
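The snapshot tests differ from the migrate tests in that they create an independent Iceberg table at a new location and leave the source untouched (verified by assertIsolatedSnapshot). A minimal sketch of the action invocation, with placeholder names:

```java
import org.apache.iceberg.actions.SnapshotTable;
import org.apache.iceberg.spark.actions.SparkActions;

class SnapshotSketch {
  // Snapshots an existing Spark/Hive table into a new Iceberg table at `location`;
  // `source` and `dest` are fully qualified names like those built by sourceName()/destName().
  static long snapshotTo(String source, String dest, String location) {
    SnapshotTable.Result result =
        SparkActions.get().snapshotTable(source).as(dest).tableLocation(location).execute();
    return result.importedDataFilesCount();
  }
}
```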
table.table().replaceSortOrder().asc("id").desc("data").commit(); + + String[] keys = {"provider", "format", "current-snapshot-id", "location", "sort-order"}; + + for (String entry : keys) { + assertThat(table.properties()) + .as("Created table missing reserved property " + entry) + .containsKey(entry); + } + + assertThat(table.properties().get("provider")).as("Unexpected provider").isEqualTo("iceberg"); + assertThat(table.properties().get("format")) + .as("Unexpected provider") + .isEqualTo("iceberg/parquet"); + assertThat(table.properties().get("current-snapshot-id")) + .as("No current-snapshot-id found") + .isNotEqualTo("none"); + assertThat(table.properties().get("location")) + .as("Location isn't correct") + .endsWith(destTableName); + + assertThat(table.properties().get("format-version")) + .as("Unexpected format-version") + .isEqualTo("1"); + table.table().updateProperties().set("format-version", "2").commit(); + assertThat(table.properties().get("format-version")) + .as("Unexpected format-version") + .isEqualTo("2"); + + assertThat(table.properties().get("sort-order")) + .as("Sort-order isn't correct") + .isEqualTo("id ASC NULLS FIRST, data DESC NULLS LAST"); + assertThat(table.properties().get("identifier-fields")) + .as("Identifier fields should be null") + .isNull(); + + table + .table() + .updateSchema() + .allowIncompatibleChanges() + .requireColumn("id") + .setIdentifierFields("id") + .commit(); + assertThat(table.properties().get("identifier-fields")) + .as("Identifier fields aren't correct") + .isEqualTo("[id]"); + } + + @TestTemplate + public void testSnapshotDefaultLocation() throws Exception { + String source = sourceName("test_snapshot_default"); + String dest = destName("iceberg_snapshot_default"); + createSourceTable(CREATE_PARTITIONED_PARQUET, source); + assertSnapshotFileCount(SparkActions.get().snapshotTable(source).as(dest), source, dest); + assertIsolatedSnapshot(source, dest); + } + + @TestTemplate + public void schemaEvolutionTestWithSparkAPI() throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + assumeThat(catalog.name()) + .as("Can only migrate from Spark Session Catalog") + .isEqualTo("spark_catalog"); + + File location = Files.createTempDirectory(temp, "junit").toFile(); + String tblName = sourceName("schema_evolution_test"); + + // Data generation and partition addition + spark + .range(0, 5) + .selectExpr("CAST(id as INT) as col0", "CAST(id AS FLOAT) col2", "CAST(id AS LONG) col3") + .write() + .mode(SaveMode.Append) + .parquet(location.toURI().toString()); + Dataset rowDataset = + spark + .range(6, 10) + .selectExpr( + "CAST(id as INT) as col0", + "CAST(id AS STRING) col1", + "CAST(id AS FLOAT) col2", + "CAST(id AS LONG) col3"); + rowDataset.write().mode(SaveMode.Append).parquet(location.toURI().toString()); + spark + .read() + .schema(rowDataset.schema()) + .parquet(location.toURI().toString()) + .write() + .saveAsTable(tblName); + List expectedBeforeAddColumn = sql("SELECT * FROM %s ORDER BY col0", tblName); + List expectedAfterAddColumn = + sql("SELECT col0, null, col1, col2, col3 FROM %s ORDER BY col0", tblName); + + // Migrate table + SparkActions.get().migrateTable(tblName).execute(); + + // check if iceberg and non-iceberg output + List afterMigarteBeforeAddResults = sql("SELECT * FROM %s ORDER BY col0", tblName); + assertEquals("Output must match", expectedBeforeAddColumn, afterMigarteBeforeAddResults); + + // Update schema and check output correctness + SparkTable sparkTable = 
loadTable(tblName); + sparkTable + .table() + .updateSchema() + .addColumn("newCol", Types.IntegerType.get()) + .moveAfter("newCol", "col0") + .commit(); + List afterMigarteAfterAddResults = sql("SELECT * FROM %s ORDER BY col0", tblName); + assertEquals("Output must match", expectedAfterAddColumn, afterMigarteAfterAddResults); + } + + @TestTemplate + public void schemaEvolutionTestWithSparkSQL() throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + assumeThat(catalog.name()) + .as("Can only migrate from Spark Session Catalog") + .isEqualTo("spark_catalog"); + String tblName = sourceName("schema_evolution_test_sql"); + + // Data generation and partition addition + spark + .range(0, 5) + .selectExpr("CAST(id as INT) col0", "CAST(id AS FLOAT) col1", "CAST(id AS STRING) col2") + .write() + .mode(SaveMode.Append) + .saveAsTable(tblName); + sql("ALTER TABLE %s ADD COLUMN col3 INT", tblName); + spark + .range(6, 10) + .selectExpr( + "CAST(id AS INT) col0", + "CAST(id AS FLOAT) col1", + "CAST(id AS STRING) col2", + "CAST(id AS INT) col3") + .registerTempTable("tempdata"); + sql("INSERT INTO TABLE %s SELECT * FROM tempdata", tblName); + List expectedBeforeAddColumn = sql("SELECT * FROM %s ORDER BY col0", tblName); + List expectedAfterAddColumn = + sql("SELECT col0, null, col1, col2, col3 FROM %s ORDER BY col0", tblName); + + // Migrate table + SparkActions.get().migrateTable(tblName).execute(); + + // check if iceberg and non-iceberg output + List afterMigarteBeforeAddResults = sql("SELECT * FROM %s ORDER BY col0", tblName); + assertEquals("Output must match", expectedBeforeAddColumn, afterMigarteBeforeAddResults); + + // Update schema and check output correctness + SparkTable sparkTable = loadTable(tblName); + sparkTable + .table() + .updateSchema() + .addColumn("newCol", Types.IntegerType.get()) + .moveAfter("newCol", "col0") + .commit(); + List afterMigarteAfterAddResults = sql("SELECT * FROM %s ORDER BY col0", tblName); + assertEquals("Output must match", expectedAfterAddColumn, afterMigarteAfterAddResults); + } + + @TestTemplate + public void testHiveStyleThreeLevelList() throws Exception { + threeLevelList(true); + } + + @TestTemplate + public void testThreeLevelList() throws Exception { + threeLevelList(false); + } + + @TestTemplate + public void testHiveStyleThreeLevelListWithNestedStruct() throws Exception { + threeLevelListWithNestedStruct(true); + } + + @TestTemplate + public void testThreeLevelListWithNestedStruct() throws Exception { + threeLevelListWithNestedStruct(false); + } + + @TestTemplate + public void testHiveStyleThreeLevelLists() throws Exception { + threeLevelLists(true); + } + + @TestTemplate + public void testThreeLevelLists() throws Exception { + threeLevelLists(false); + } + + @TestTemplate + public void testHiveStyleStructOfThreeLevelLists() throws Exception { + structOfThreeLevelLists(true); + } + + @TestTemplate + public void testStructOfThreeLevelLists() throws Exception { + structOfThreeLevelLists(false); + } + + @TestTemplate + public void testTwoLevelList() throws IOException { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + + spark.conf().set("spark.sql.parquet.writeLegacyFormat", true); + + String tableName = sourceName("testTwoLevelList"); + File location = Files.createTempDirectory(temp, "junit").toFile(); + + StructType sparkSchema = + new StructType( + new StructField[] { + new StructField( + "col1", + new ArrayType( + new StructType( + new StructField[] { + 
new StructField("col2", DataTypes.IntegerType, false, Metadata.empty()) + }), + false), + true, + Metadata.empty()) + }); + + // even though this list looks like three level list, it is actually a 2-level list where the + // items are + // structs with 1 field. + String expectedParquetSchema = + "message spark_schema {\n" + + " optional group col1 (LIST) {\n" + + " repeated group array {\n" + + " required int32 col2;\n" + + " }\n" + + " }\n" + + "}\n"; + + // generate parquet file with required schema + List testData = Collections.singletonList("{\"col1\": [{\"col2\": 1}]}"); + spark + .read() + .schema(sparkSchema) + .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(testData)) + .coalesce(1) + .write() + .format("parquet") + .mode(SaveMode.Append) + .save(location.getPath()); + + File parquetFile = + Arrays.stream( + Preconditions.checkNotNull( + location.listFiles( + new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + return name.endsWith("parquet"); + } + }))) + .findAny() + .get(); + + // verify generated parquet file has expected schema + ParquetFileReader pqReader = + ParquetFileReader.open( + HadoopInputFile.fromPath( + new Path(parquetFile.getPath()), spark.sessionState().newHadoopConf())); + MessageType schema = pqReader.getFooter().getFileMetaData().getSchema(); + assertThat(schema).isEqualTo(MessageTypeParser.parseMessageType(expectedParquetSchema)); + + // create sql table on top of it + sql( + "CREATE EXTERNAL TABLE %s (col1 ARRAY>)" + + " STORED AS parquet" + + " LOCATION '%s'", + tableName, location); + List expected = sql("select array(struct(1))"); + + // migrate table + SparkActions.get().migrateTable(tableName).execute(); + + // check migrated table is returning expected result + List results = sql("SELECT * FROM %s", tableName); + assertThat(results).isNotEmpty(); + assertEquals("Output must match", expected, results); + } + + private void threeLevelList(boolean useLegacyMode) throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + + spark.conf().set("spark.sql.parquet.writeLegacyFormat", useLegacyMode); + + String tableName = sourceName(String.format("threeLevelList_%s", useLegacyMode)); + File location = Files.createTempDirectory(temp, "junit").toFile(); + sql( + "CREATE TABLE %s (col1 ARRAY>)" + " STORED AS parquet" + " LOCATION '%s'", + tableName, location); + + int testValue = 12345; + sql("INSERT INTO %s VALUES (ARRAY(STRUCT(%s)))", tableName, testValue); + List expected = sql(String.format("SELECT * FROM %s", tableName)); + + // migrate table + SparkActions.get().migrateTable(tableName).execute(); + + // check migrated table is returning expected result + List results = sql("SELECT * FROM %s", tableName); + assertThat(results).isNotEmpty(); + assertEquals("Output must match", expected, results); + } + + private void threeLevelListWithNestedStruct(boolean useLegacyMode) throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + + spark.conf().set("spark.sql.parquet.writeLegacyFormat", useLegacyMode); + + String tableName = + sourceName(String.format("threeLevelListWithNestedStruct_%s", useLegacyMode)); + File location = Files.createTempDirectory(temp, "junit").toFile(); + sql( + "CREATE TABLE %s (col1 ARRAY>>)" + + " STORED AS parquet" + + " LOCATION '%s'", + tableName, location); + + int testValue = 12345; + sql("INSERT INTO %s VALUES (ARRAY(STRUCT(STRUCT(%s))))", tableName, testValue); + List 
expected = sql(String.format("SELECT * FROM %s", tableName)); + + // migrate table + SparkActions.get().migrateTable(tableName).execute(); + + // check migrated table is returning expected result + List results = sql("SELECT * FROM %s", tableName); + assertThat(results).isNotEmpty(); + assertEquals("Output must match", expected, results); + } + + private void threeLevelLists(boolean useLegacyMode) throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + + spark.conf().set("spark.sql.parquet.writeLegacyFormat", useLegacyMode); + + String tableName = sourceName(String.format("threeLevelLists_%s", useLegacyMode)); + File location = Files.createTempDirectory(temp, "junit").toFile(); + sql( + "CREATE TABLE %s (col1 ARRAY>, col3 ARRAY>)" + + " STORED AS parquet" + + " LOCATION '%s'", + tableName, location); + + int testValue1 = 12345; + int testValue2 = 987654; + sql( + "INSERT INTO %s VALUES (ARRAY(STRUCT(%s)), ARRAY(STRUCT(%s)))", + tableName, testValue1, testValue2); + List expected = sql(String.format("SELECT * FROM %s", tableName)); + + // migrate table + SparkActions.get().migrateTable(tableName).execute(); + + // check migrated table is returning expected result + List results = sql("SELECT * FROM %s", tableName); + assertThat(results).isNotEmpty(); + assertEquals("Output must match", expected, results); + } + + private void structOfThreeLevelLists(boolean useLegacyMode) throws Exception { + assumeThat(type).as("Cannot migrate to a hadoop based catalog").isNotEqualTo("hadoop"); + + spark.conf().set("spark.sql.parquet.writeLegacyFormat", useLegacyMode); + + String tableName = sourceName(String.format("structOfThreeLevelLists_%s", useLegacyMode)); + File location = Files.createTempDirectory(temp, "junit").toFile(); + sql( + "CREATE TABLE %s (col1 STRUCT>>)" + + " STORED AS parquet" + + " LOCATION '%s'", + tableName, location); + + int testValue1 = 12345; + sql("INSERT INTO %s VALUES (STRUCT(STRUCT(ARRAY(STRUCT(%s)))))", tableName, testValue1); + List expected = sql(String.format("SELECT * FROM %s", tableName)); + + // migrate table + SparkActions.get().migrateTable(tableName).execute(); + + // check migrated table is returning expected result + List results = sql("SELECT * FROM %s", tableName); + assertThat(results).isNotEmpty(); + assertEquals("Output must match", expected, results); + } + + private SparkTable loadTable(String name) throws NoSuchTableException, ParseException { + return (SparkTable) + catalog.loadTable(Spark3Util.catalogAndIdentifier(spark, name).identifier()); + } + + private CatalogTable loadSessionTable(String name) + throws NoSuchTableException, NoSuchDatabaseException, ParseException { + Identifier identifier = Spark3Util.catalogAndIdentifier(spark, name).identifier(); + Some namespace = Some.apply(identifier.namespace()[0]); + return spark + .sessionState() + .catalog() + .getTableMetadata(new TableIdentifier(identifier.name(), namespace)); + } + + private void createSourceTable(String createStatement, String tableName) + throws IOException, NoSuchTableException, NoSuchDatabaseException, ParseException { + File location = Files.createTempDirectory(temp, "junit").toFile(); + spark.sql(String.format(createStatement, tableName, location)); + CatalogTable table = loadSessionTable(tableName); + String format = table.provider().get(); + spark + .table(baseTableName) + .selectExpr(table.schema().names()) + .write() + .mode(SaveMode.Append) + .format(format) + .insertInto(tableName); + } + + // Counts the number of files 
in the source table, makes sure the same files exist in the + // destination table + private void assertMigratedFileCount(MigrateTable migrateAction, String source, String dest) + throws NoSuchTableException, NoSuchDatabaseException, ParseException { + long expectedFiles = expectedFilesCount(source); + MigrateTable.Result migratedFiles = migrateAction.execute(); + validateTables(source, dest); + assertThat(migratedFiles.migratedDataFilesCount()) + .as("Expected number of migrated files") + .isEqualTo(expectedFiles); + } + + // Counts the number of files in the source table, makes sure the same files exist in the + // destination table + private void assertSnapshotFileCount(SnapshotTable snapshotTable, String source, String dest) + throws NoSuchTableException, NoSuchDatabaseException, ParseException { + long expectedFiles = expectedFilesCount(source); + SnapshotTable.Result snapshotTableResult = snapshotTable.execute(); + validateTables(source, dest); + assertThat(snapshotTableResult.importedDataFilesCount()) + .as("Expected number of imported snapshot files") + .isEqualTo(expectedFiles); + } + + private void validateTables(String source, String dest) + throws NoSuchTableException, ParseException { + List expected = spark.table(source).collectAsList(); + SparkTable destTable = loadTable(dest); + assertThat(destTable.properties().get(TableCatalog.PROP_PROVIDER)) + .as("Provider should be iceberg") + .isEqualTo("iceberg"); + List actual = spark.table(dest).collectAsList(); + assertThat(actual) + .as( + String.format( + "Rows in migrated table did not match\nExpected :%s rows \nFound :%s", + expected, actual)) + .containsAll(expected); + assertThat(expected) + .as( + String.format( + "Rows in migrated table did not match\nExpected :%s rows \nFound :%s", + expected, actual)) + .containsAll(actual); + } + + private long expectedFilesCount(String source) + throws NoSuchDatabaseException, NoSuchTableException, ParseException { + CatalogTable sourceTable = loadSessionTable(source); + List uris; + if (sourceTable.partitionColumnNames().isEmpty()) { + uris = Lists.newArrayList(); + uris.add(sourceTable.location()); + } else { + Seq catalogTablePartitionSeq = + spark + .sessionState() + .catalog() + .listPartitions(sourceTable.identifier(), Option.apply(null)); + uris = + JavaConverters.seqAsJavaList(catalogTablePartitionSeq).stream() + .map(CatalogTablePartition::location) + .collect(Collectors.toList()); + } + return uris.stream() + .flatMap( + uri -> + FileUtils.listFiles( + Paths.get(uri).toFile(), TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE) + .stream()) + .filter(file -> !file.toString().endsWith("crc") && !file.toString().contains("_SUCCESS")) + .count(); + } + + // Insert records into the destination, makes sure those records exist and source table is + // unchanged + private void assertIsolatedSnapshot(String source, String dest) { + List expected = spark.sql(String.format("SELECT * FROM %s", source)).collectAsList(); + + List extraData = Lists.newArrayList(new SimpleRecord(4, "d")); + Dataset df = spark.createDataFrame(extraData, SimpleRecord.class); + df.write().format("iceberg").mode("append").saveAsTable(dest); + + List result = spark.sql(String.format("SELECT * FROM %s", source)).collectAsList(); + assertThat(result) + .as("No additional rows should be added to the original table") + .hasSameSizeAs(expected); + + List snapshot = + spark + .sql(String.format("SELECT * FROM %s WHERE id = 4 AND data = 'd'", dest)) + .collectAsList(); + assertThat(snapshot).as("Added row not found in 
snapshot").hasSize(1); + } + + private String sourceName(String source) { + return NAMESPACE + "." + catalog.name() + "_" + type + "_" + source; + } + + private String destName(String dest) { + if (catalog.name().equals("spark_catalog")) { + return NAMESPACE + "." + catalog.name() + "_" + type + "_" + dest; + } else { + return catalog.name() + "." + NAMESPACE + "." + catalog.name() + "_" + type + "_" + dest; + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestDeleteReachableFilesAction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestDeleteReachableFilesAction.java new file mode 100644 index 000000000000..d5bb63b2d88a --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestDeleteReachableFilesAction.java @@ -0,0 +1,406 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.File; +import java.util.Arrays; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.StreamSupport; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileGenerationUtil; +import org.apache.iceberg.FileMetadata; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.actions.ActionsProvider; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.actions.DeleteReachableFiles; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.types.Types; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import 
org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestDeleteReachableFilesAction extends TestBase { + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + private static final int SHUFFLE_PARTITIONS = 2; + + private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA).identity("c1").build(); + + static final DataFile FILE_A = + DataFiles.builder(SPEC) + .withPath("/path/to/data-a.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(0)) + .withRecordCount(1) + .build(); + static final DataFile FILE_B = + DataFiles.builder(SPEC) + .withPath("/path/to/data-b.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(1)) + .withRecordCount(1) + .build(); + static final DataFile FILE_C = + DataFiles.builder(SPEC) + .withPath("/path/to/data-c.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(2)) + .withRecordCount(1) + .build(); + static final DataFile FILE_D = + DataFiles.builder(SPEC) + .withPath("/path/to/data-d.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(3)) + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A_POS_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofPositionDeletes() + .withPath("/path/to/data-a-pos-deletes.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(0)) + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A_EQ_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofEqualityDeletes() + .withPath("/path/to/data-a-eq-deletes.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(0)) + .withRecordCount(1) + .build(); + + @TempDir private File tableDir; + @Parameter private int formatVersion; + + @Parameters(name = "formatVersion = {0}") + protected static List parameters() { + return Arrays.asList(2, 3); + } + + private Table table; + + @BeforeEach + public void setupTableLocation() throws Exception { + String tableLocation = tableDir.toURI().toString(); + this.table = + TABLES.create( + SCHEMA, + SPEC, + ImmutableMap.of(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)), + tableLocation); + spark.conf().set("spark.sql.shuffle.partitions", SHUFFLE_PARTITIONS); + } + + private void checkRemoveFilesResults( + long expectedDatafiles, + long expectedPosDeleteFiles, + long expectedEqDeleteFiles, + long expectedManifestsDeleted, + long expectedManifestListsDeleted, + long expectedOtherFilesDeleted, + DeleteReachableFiles.Result results) { + assertThat(results.deletedManifestsCount()) + .as("Incorrect number of manifest files deleted") + .isEqualTo(expectedManifestsDeleted); + + assertThat(results.deletedDataFilesCount()) + .as("Incorrect number of datafiles deleted") + .isEqualTo(expectedDatafiles); + + assertThat(results.deletedPositionDeleteFilesCount()) + .as("Incorrect number of position delete files deleted") + .isEqualTo(expectedPosDeleteFiles); + + assertThat(results.deletedEqualityDeleteFilesCount()) + .as("Incorrect number of equality delete files deleted") + .isEqualTo(expectedEqDeleteFiles); + + assertThat(results.deletedManifestListsCount()) + .as("Incorrect number of manifest lists deleted") + .isEqualTo(expectedManifestListsDeleted); + + 
+    assertThat(results.deletedOtherFilesCount())
+        .as("Incorrect number of other files deleted")
+        .isEqualTo(expectedOtherFilesDeleted);
+  }
+
+  @TestTemplate
+  public void dataFilesCleanupWithParallelTasks() {
+    table.newFastAppend().appendFile(FILE_A).commit();
+
+    table.newFastAppend().appendFile(FILE_B).commit();
+
+    table.newRewrite().rewriteFiles(ImmutableSet.of(FILE_B), ImmutableSet.of(FILE_D)).commit();
+
+    table.newRewrite().rewriteFiles(ImmutableSet.of(FILE_A), ImmutableSet.of(FILE_C)).commit();
+
+    Set<String> deletedFiles = ConcurrentHashMap.newKeySet();
+    Set<String> deleteThreads = ConcurrentHashMap.newKeySet();
+    AtomicInteger deleteThreadsIndex = new AtomicInteger(0);
+
+    DeleteReachableFiles.Result result =
+        sparkActions()
+            .deleteReachableFiles(metadataLocation(table))
+            .io(table.io())
+            .executeDeleteWith(
+                Executors.newFixedThreadPool(
+                    4,
+                    runnable -> {
+                      Thread thread = new Thread(runnable);
+                      thread.setName("remove-files-" + deleteThreadsIndex.getAndIncrement());
+                      thread.setDaemon(
+                          true); // daemon threads will be terminated abruptly when the JVM exits
+                      return thread;
+                    }))
+            .deleteWith(
+                s -> {
+                  deleteThreads.add(Thread.currentThread().getName());
+                  deletedFiles.add(s);
+                })
+            .execute();
+
+    // Verifies that the delete methods ran in the threads created by the provided ExecutorService
+    // ThreadFactory
+    assertThat(deleteThreads)
+        .isEqualTo(
+            Sets.newHashSet(
+                "remove-files-0", "remove-files-1", "remove-files-2", "remove-files-3"));
+
+    Lists.newArrayList(FILE_A, FILE_B, FILE_C, FILE_D)
+        .forEach(
+            file ->
+                assertThat(deletedFiles)
+                    .as(file.location() + " should be deleted")
+                    .contains(file.location()));
+    checkRemoveFilesResults(4L, 0, 0, 6L, 4L, 6, result);
+  }
+
+  @TestTemplate
+  public void testWithExpiringDanglingStageCommit() {
+    table.location();
+    // `A` commit
+    table.newAppend().appendFile(FILE_A).commit();
+
+    // `B` staged commit
+    table.newAppend().appendFile(FILE_B).stageOnly().commit();
+
+    // `C` commit
+    table.newAppend().appendFile(FILE_C).commit();
+
+    DeleteReachableFiles.Result result =
+        sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io()).execute();
+
+    checkRemoveFilesResults(3L, 0, 0, 3L, 3L, 5, result);
+  }
+
+  @TestTemplate
+  public void testRemoveFileActionOnEmptyTable() {
+    DeleteReachableFiles.Result result =
+        sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io()).execute();
+
+    checkRemoveFilesResults(0, 0, 0, 0, 0, 2, result);
+  }
+
+  @TestTemplate
+  public void testRemoveFilesActionWithReducedVersionsTable() {
+    table.updateProperties().set(TableProperties.METADATA_PREVIOUS_VERSIONS_MAX, "2").commit();
+    table.newAppend().appendFile(FILE_A).commit();
+
+    table.newAppend().appendFile(FILE_B).commit();
+
+    table.newAppend().appendFile(FILE_B).commit();
+
+    table.newAppend().appendFile(FILE_C).commit();
+
+    table.newAppend().appendFile(FILE_D).commit();
+
+    DeleteReachableFiles baseRemoveFilesSparkAction =
+        sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io());
+    DeleteReachableFiles.Result result = baseRemoveFilesSparkAction.execute();
+
+    checkRemoveFilesResults(4, 0, 0, 5, 5, 8, result);
+  }
+
+  @TestTemplate
+  public void testRemoveFilesAction() {
+    table.newAppend().appendFile(FILE_A).commit();
+
+    table.newAppend().appendFile(FILE_B).commit();
+
+    DeleteReachableFiles baseRemoveFilesSparkAction =
+        sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io());
+    checkRemoveFilesResults(2, 0, 0, 2, 2, 4, baseRemoveFilesSparkAction.execute());
+  }
+
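+  // The tests above exercise the action's fluent API end to end. As a minimal standalone sketch
+  // (illustrative only; the metadata path below is a placeholder, not a location used by these
+  // tests), the same action is typically driven as:
+  //
+  //   DeleteReachableFiles.Result result =
+  //       SparkActions.get()
+  //           .deleteReachableFiles("/path/to/table/metadata/v3.metadata.json")
+  //           .io(table.io())
+  //           .option("stream-results", "true")
+  //           .execute();
+  //   long removedDataFiles = result.deletedDataFilesCount();
+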
@TestTemplate + public void testPositionDeleteFiles() { + table.newAppend().appendFile(FILE_A).commit(); + + table.newAppend().appendFile(FILE_B).commit(); + + table.newRowDelta().addDeletes(fileADeletes()).commit(); + + DeleteReachableFiles baseRemoveFilesSparkAction = + sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io()); + checkRemoveFilesResults(2, 1, 0, 3, 3, 5, baseRemoveFilesSparkAction.execute()); + } + + @TestTemplate + public void testEqualityDeleteFiles() { + table.newAppend().appendFile(FILE_A).commit(); + + table.newAppend().appendFile(FILE_B).commit(); + + table.newRowDelta().addDeletes(FILE_A_EQ_DELETES).commit(); + + DeleteReachableFiles baseRemoveFilesSparkAction = + sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io()); + checkRemoveFilesResults(2, 0, 1, 3, 3, 5, baseRemoveFilesSparkAction.execute()); + } + + @TestTemplate + public void testRemoveFilesActionWithDefaultIO() { + table.newAppend().appendFile(FILE_A).commit(); + + table.newAppend().appendFile(FILE_B).commit(); + + // IO not set explicitly on removeReachableFiles action + // IO defaults to HadoopFileIO + DeleteReachableFiles baseRemoveFilesSparkAction = + sparkActions().deleteReachableFiles(metadataLocation(table)); + checkRemoveFilesResults(2, 0, 0, 2, 2, 4, baseRemoveFilesSparkAction.execute()); + } + + @TestTemplate + public void testUseLocalIterator() { + table.newFastAppend().appendFile(FILE_A).commit(); + + table.newOverwrite().deleteFile(FILE_A).addFile(FILE_B).commit(); + + table.newFastAppend().appendFile(FILE_C).commit(); + + int jobsBefore = spark.sparkContext().dagScheduler().nextJobId().get(); + + withSQLConf( + ImmutableMap.of("spark.sql.adaptive.enabled", "false"), + () -> { + DeleteReachableFiles.Result results = + sparkActions() + .deleteReachableFiles(metadataLocation(table)) + .io(table.io()) + .option("stream-results", "true") + .execute(); + + int jobsAfter = spark.sparkContext().dagScheduler().nextJobId().get(); + int totalJobsRun = jobsAfter - jobsBefore; + + checkRemoveFilesResults(3L, 0, 0, 4L, 3L, 5, results); + + assertThat(totalJobsRun) + .as("Expected total jobs to be equal to total number of shuffle partitions") + .isEqualTo(SHUFFLE_PARTITIONS); + }); + } + + @TestTemplate + public void testIgnoreMetadataFilesNotFound() { + table.updateProperties().set(TableProperties.METADATA_PREVIOUS_VERSIONS_MAX, "1").commit(); + + table.newAppend().appendFile(FILE_A).commit(); + // There are three metadata json files at this point + DeleteOrphanFiles.Result result = + sparkActions().deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + assertThat(result.orphanFileLocations()).as("Should delete 1 file").hasSize(1); + assertThat(StreamSupport.stream(result.orphanFileLocations().spliterator(), false)) + .as("Should remove v1 file") + .anyMatch(file -> file.contains("v1.metadata.json")); + + DeleteReachableFiles baseRemoveFilesSparkAction = + sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io()); + DeleteReachableFiles.Result res = baseRemoveFilesSparkAction.execute(); + + checkRemoveFilesResults(1, 0, 0, 1, 1, 4, res); + } + + @TestTemplate + public void testEmptyIOThrowsException() { + DeleteReachableFiles baseRemoveFilesSparkAction = + sparkActions().deleteReachableFiles(metadataLocation(table)).io(null); + + assertThatThrownBy(baseRemoveFilesSparkAction::execute) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("File IO cannot be null"); + } + + @TestTemplate + public void 
testRemoveFilesActionWhenGarbageCollectionDisabled() { + table.updateProperties().set(TableProperties.GC_ENABLED, "false").commit(); + + assertThatThrownBy(() -> sparkActions().deleteReachableFiles(metadataLocation(table)).execute()) + .isInstanceOf(ValidationException.class) + .hasMessage( + "Cannot delete files: GC is disabled (deleting files may corrupt other tables)"); + } + + private String metadataLocation(Table tbl) { + return ((HasTableOperations) tbl).operations().current().metadataFileLocation(); + } + + private ActionsProvider sparkActions() { + return SparkActions.get(); + } + + private DeleteFile fileADeletes() { + return formatVersion >= 3 ? FileGenerationUtil.generateDV(table, FILE_A) : FILE_A_POS_DELETES; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestExpireSnapshotsAction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestExpireSnapshotsAction.java new file mode 100644 index 000000000000..ffbe988e8d41 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestExpireSnapshotsAction.java @@ -0,0 +1,1359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileGenerationUtil; +import org.apache.iceberg.FileMetadata; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.ReachableFileUtil; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.ExpireSnapshots; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestExpireSnapshotsAction extends TestBase { + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + private static final int SHUFFLE_PARTITIONS = 2; + + private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA).identity("c1").build(); + + static final DataFile FILE_A = + DataFiles.builder(SPEC) + .withPath("/path/to/data-a.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=0") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_B = + DataFiles.builder(SPEC) + .withPath("/path/to/data-b.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=1") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_C = + DataFiles.builder(SPEC) + .withPath("/path/to/data-c.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=2") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile 
FILE_D = + DataFiles.builder(SPEC) + .withPath("/path/to/data-d.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=3") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A_POS_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofPositionDeletes() + .withPath("/path/to/data-a-pos-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=0") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A_EQ_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofEqualityDeletes() + .withPath("/path/to/data-a-eq-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=0") // easy way to set partition data for now + .withRecordCount(1) + .build(); + + @TempDir private Path temp; + @Parameter private int formatVersion; + + @Parameters(name = "formatVersion = {0}") + protected static List parameters() { + return Arrays.asList(2, 3); + } + + @TempDir private File tableDir; + private String tableLocation; + private Table table; + + @BeforeEach + public void setupTableLocation() throws Exception { + this.tableLocation = tableDir.toURI().toString(); + this.table = + TABLES.create( + SCHEMA, + SPEC, + ImmutableMap.of(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)), + tableLocation); + spark.conf().set("spark.sql.shuffle.partitions", SHUFFLE_PARTITIONS); + } + + private Long rightAfterSnapshot() { + return rightAfterSnapshot(table.currentSnapshot().snapshotId()); + } + + private Long rightAfterSnapshot(long snapshotId) { + Long end = System.currentTimeMillis(); + while (end <= table.snapshot(snapshotId).timestampMillis()) { + end = System.currentTimeMillis(); + } + return end; + } + + private DeleteFile fileADeletes() { + return formatVersion >= 3 ? 
FileGenerationUtil.generateDV(table, FILE_A) : FILE_A_POS_DELETES; + } + + private void checkExpirationResults( + long expectedDatafiles, + long expectedPosDeleteFiles, + long expectedEqDeleteFiles, + long expectedManifestsDeleted, + long expectedManifestListsDeleted, + ExpireSnapshots.Result results) { + + assertThat(results.deletedManifestsCount()) + .as("Incorrect number of manifest files deleted") + .isEqualTo(expectedManifestsDeleted); + + assertThat(results.deletedDataFilesCount()) + .as("Incorrect number of datafiles deleted") + .isEqualTo(expectedDatafiles); + + assertThat(results.deletedPositionDeleteFilesCount()) + .as("Incorrect number of pos deletefiles deleted") + .isEqualTo(expectedPosDeleteFiles); + + assertThat(results.deletedEqualityDeleteFilesCount()) + .as("Incorrect number of eq deletefiles deleted") + .isEqualTo(expectedEqDeleteFiles); + + assertThat(results.deletedManifestListsCount()) + .as("Incorrect number of manifest lists deleted") + .isEqualTo(expectedManifestListsDeleted); + } + + @TestTemplate + public void testFilesCleaned() throws Exception { + table.newFastAppend().appendFile(FILE_A).commit(); + + table.newOverwrite().deleteFile(FILE_A).addFile(FILE_B).commit(); + + table.newFastAppend().appendFile(FILE_C).commit(); + + long end = rightAfterSnapshot(); + + ExpireSnapshots.Result results = + SparkActions.get().expireSnapshots(table).expireOlderThan(end).execute(); + + assertThat(table.snapshots()).as("Table does not have 1 snapshot after expiration").hasSize(1); + + checkExpirationResults(1L, 0L, 0L, 1L, 2L, results); + } + + @TestTemplate + public void dataFilesCleanupWithParallelTasks() throws IOException { + + table.newFastAppend().appendFile(FILE_A).commit(); + + table.newFastAppend().appendFile(FILE_B).commit(); + + table.newRewrite().rewriteFiles(ImmutableSet.of(FILE_B), ImmutableSet.of(FILE_D)).commit(); + + table.newRewrite().rewriteFiles(ImmutableSet.of(FILE_A), ImmutableSet.of(FILE_C)).commit(); + + long t4 = rightAfterSnapshot(); + + Set deletedFiles = ConcurrentHashMap.newKeySet(); + Set deleteThreads = ConcurrentHashMap.newKeySet(); + AtomicInteger deleteThreadsIndex = new AtomicInteger(0); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .executeDeleteWith( + Executors.newFixedThreadPool( + 4, + runnable -> { + Thread thread = new Thread(runnable); + thread.setName("remove-snapshot-" + deleteThreadsIndex.getAndIncrement()); + thread.setDaemon( + true); // daemon threads will be terminated abruptly when the JVM exits + return thread; + })) + .expireOlderThan(t4) + .deleteWith( + s -> { + deleteThreads.add(Thread.currentThread().getName()); + deletedFiles.add(s); + }) + .execute(); + + // Verifies that the delete methods ran in the threads created by the provided ExecutorService + // ThreadFactory + assertThat(deleteThreads) + .isEqualTo( + Sets.newHashSet( + "remove-snapshot-0", + "remove-snapshot-1", + "remove-snapshot-2", + "remove-snapshot-3")); + + assertThat(deletedFiles).as("FILE_A should be deleted").contains(FILE_A.location()); + assertThat(deletedFiles).as("FILE_B should be deleted").contains(FILE_B.location()); + + checkExpirationResults(2L, 0L, 0L, 3L, 3L, result); + } + + @TestTemplate + public void testNoFilesDeletedWhenNoSnapshotsExpired() throws Exception { + table.newFastAppend().appendFile(FILE_A).commit(); + + ExpireSnapshots.Result results = SparkActions.get().expireSnapshots(table).execute(); + checkExpirationResults(0L, 0L, 0L, 0L, 0L, results); + } + + @TestTemplate + public void 
testCleanupRepeatedOverwrites() throws Exception { + table.newFastAppend().appendFile(FILE_A).commit(); + + for (int i = 0; i < 10; i++) { + table.newOverwrite().deleteFile(FILE_A).addFile(FILE_B).commit(); + + table.newOverwrite().deleteFile(FILE_B).addFile(FILE_A).commit(); + } + + long end = rightAfterSnapshot(); + ExpireSnapshots.Result results = + SparkActions.get().expireSnapshots(table).expireOlderThan(end).execute(); + checkExpirationResults(1L, 0L, 0L, 39L, 20L, results); + } + + @TestTemplate + public void testRetainLastWithExpireOlderThan() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + long t1 = System.currentTimeMillis(); + while (t1 <= table.currentSnapshot().timestampMillis()) { + t1 = System.currentTimeMillis(); + } + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + long t3 = rightAfterSnapshot(); + + // Retain last 2 snapshots + SparkActions.get().expireSnapshots(table).expireOlderThan(t3).retainLast(2).execute(); + + assertThat(table.snapshots()).as("Should have two snapshots.").hasSize(2); + assertThat(table.snapshot(firstSnapshotId)).as("First snapshot should not present.").isNull(); + } + + @TestTemplate + public void testExpireTwoSnapshotsById() throws Exception { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + long secondSnapshotID = table.currentSnapshot().snapshotId(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + // Retain last 2 snapshots + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireSnapshotId(firstSnapshotId) + .expireSnapshotId(secondSnapshotID) + .execute(); + + assertThat(table.snapshots()).as("Should have one snapshot.").hasSize(1); + assertThat(table.snapshot(firstSnapshotId)).as("First snapshot should not present.").isNull(); + assertThat(table.snapshot(secondSnapshotID)).as("Second snapshot should not present.").isNull(); + + checkExpirationResults(0L, 0L, 0L, 0L, 2L, result); + } + + @TestTemplate + public void testRetainLastWithExpireById() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + // Retain last 3 snapshots, but explicitly remove the first snapshot + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireSnapshotId(firstSnapshotId) + .retainLast(3) + .execute(); + + assertThat(table.snapshots()).as("Should have 2 snapshots.").hasSize(2); + assertThat(table.snapshot(firstSnapshotId)).as("First snapshot should not present.").isNull(); + checkExpirationResults(0L, 0L, 0L, 0L, 1L, result); + } + + @TestTemplate + public void testRetainLastWithTooFewSnapshots() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .appendFile(FILE_B) // data_bucket=1 + .commit(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + long t2 = rightAfterSnapshot(); + + // Retain last 3 snapshots + ExpireSnapshots.Result result = + 
SparkActions.get().expireSnapshots(table).expireOlderThan(t2).retainLast(3).execute(); + + assertThat(table.snapshots()).as("Should have two snapshots.").hasSize(2); + assertThat(table.snapshot(firstSnapshotId)) + .as("First snapshot should still be present.") + .isNotNull(); + checkExpirationResults(0L, 0L, 0L, 0L, 0L, result); + } + + @TestTemplate + public void testRetainLastKeepsExpiringSnapshot() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + table + .newAppend() + .appendFile(FILE_D) // data_bucket=3 + .commit(); + + // Retain last 2 snapshots and expire older than t3 + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(secondSnapshot.timestampMillis()) + .retainLast(2) + .execute(); + + assertThat(table.snapshots()).as("Should have three snapshots.").hasSize(3); + assertThat(table.snapshot(secondSnapshot.snapshotId())) + .as("First snapshot should be present.") + .isNotNull(); + checkExpirationResults(0L, 0L, 0L, 0L, 1L, result); + } + + @TestTemplate + public void testExpireSnapshotsWithDisabledGarbageCollection() { + table.updateProperties().set(TableProperties.GC_ENABLED, "false").commit(); + + table.newAppend().appendFile(FILE_A).commit(); + + assertThatThrownBy(() -> SparkActions.get().expireSnapshots(table)) + .isInstanceOf(ValidationException.class) + .hasMessage( + "Cannot expire snapshots: GC is disabled (deleting files may corrupt other tables)"); + } + + @TestTemplate + public void testExpireOlderThanMultipleCalls() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + Snapshot thirdSnapshot = table.currentSnapshot(); + + // Retain last 2 snapshots and expire older than t3 + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(secondSnapshot.timestampMillis()) + .expireOlderThan(thirdSnapshot.timestampMillis()) + .execute(); + + assertThat(table.snapshots()).as("Should have one snapshot.").hasSize(1); + assertThat(table.snapshot(secondSnapshot.snapshotId())) + .as("Second snapshot should not present.") + .isNull(); + checkExpirationResults(0L, 0L, 0L, 0L, 2L, result); + } + + @TestTemplate + public void testRetainLastMultipleCalls() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + long t3 = rightAfterSnapshot(); + + // Retain last 2 snapshots and expire older than t3 + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(t3) + .retainLast(2) + .retainLast(1) + .execute(); + + assertThat(table.snapshots()).as("Should have one snapshot.").hasSize(1); + assertThat(table.snapshot(secondSnapshot.snapshotId())) + .as("Second snapshot should not present.") + .isNull(); + checkExpirationResults(0L, 0L, 0L, 0L, 2L, result); + } + + @TestTemplate + public void testRetainZeroSnapshots() { + assertThatThrownBy(() -> 
SparkActions.get().expireSnapshots(table).retainLast(0).execute()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Number of snapshots to retain must be at least 1, cannot be: 0"); + } + + @TestTemplate + public void testScanExpiredManifestInValidSnapshotAppend() { + table.newAppend().appendFile(FILE_A).appendFile(FILE_B).commit(); + + table.newOverwrite().addFile(FILE_C).deleteFile(FILE_A).commit(); + + table.newAppend().appendFile(FILE_D).commit(); + + long t3 = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(t3) + .deleteWith(deletedFiles::add) + .execute(); + + assertThat(deletedFiles).as("FILE_A should be deleted").contains(FILE_A.location()); + checkExpirationResults(1L, 0L, 0L, 1L, 2L, result); + } + + @TestTemplate + public void testScanExpiredManifestInValidSnapshotFastAppend() { + table + .updateProperties() + .set(TableProperties.MANIFEST_MERGE_ENABLED, "true") + .set(TableProperties.MANIFEST_MIN_MERGE_COUNT, "1") + .commit(); + + table.newAppend().appendFile(FILE_A).appendFile(FILE_B).commit(); + + table.newOverwrite().addFile(FILE_C).deleteFile(FILE_A).commit(); + + table.newFastAppend().appendFile(FILE_D).commit(); + + long t3 = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(t3) + .deleteWith(deletedFiles::add) + .execute(); + + assertThat(deletedFiles).as("FILE_A should be deleted").contains(FILE_A.location()); + checkExpirationResults(1L, 0L, 0L, 1L, 2L, result); + } + + /** + * Test on table below, and expiring the staged commit `B` using `expireOlderThan` API. Table: A - + * C ` B (staged) + */ + @TestTemplate + public void testWithExpiringDanglingStageCommit() { + // `A` commit + table.newAppend().appendFile(FILE_A).commit(); + + // `B` staged commit + table.newAppend().appendFile(FILE_B).stageOnly().commit(); + + TableMetadata base = ((BaseTable) table).operations().current(); + Snapshot snapshotA = base.snapshots().get(0); + Snapshot snapshotB = base.snapshots().get(1); + + // `C` commit + table.newAppend().appendFile(FILE_C).commit(); + + Set deletedFiles = Sets.newHashSet(); + + // Expire all commits including dangling staged snapshot. + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .deleteWith(deletedFiles::add) + .expireOlderThan(snapshotB.timestampMillis() + 1) + .execute(); + + checkExpirationResults(1L, 0L, 0L, 1L, 2L, result); + + Set expectedDeletes = Sets.newHashSet(); + expectedDeletes.add(snapshotA.manifestListLocation()); + + // Files should be deleted of dangling staged snapshot + snapshotB + .addedDataFiles(table.io()) + .forEach( + i -> { + expectedDeletes.add(i.location()); + }); + + // ManifestList should be deleted too + expectedDeletes.add(snapshotB.manifestListLocation()); + snapshotB + .dataManifests(table.io()) + .forEach( + file -> { + // Only the manifest of B should be deleted. 
+ if (file.snapshotId() == snapshotB.snapshotId()) { + expectedDeletes.add(file.path()); + } + }); + assertThat(expectedDeletes) + .as("Files deleted count should be expected") + .hasSameSizeAs(deletedFiles); + // Take the diff + expectedDeletes.removeAll(deletedFiles); + assertThat(expectedDeletes).as("Exactly same files should be deleted").isEmpty(); + } + + /** + * Expire cherry-pick the commit as shown below, when `B` is in table's current state Table: A - B + * - C <--current snapshot `- D (source=B) + */ + @TestTemplate + public void testWithCherryPickTableSnapshot() { + // `A` commit + table.newAppend().appendFile(FILE_A).commit(); + Snapshot snapshotA = table.currentSnapshot(); + + // `B` commit + Set deletedAFiles = Sets.newHashSet(); + table.newOverwrite().addFile(FILE_B).deleteFile(FILE_A).deleteWith(deletedAFiles::add).commit(); + assertThat(deletedAFiles).as("No files should be physically deleted").isEmpty(); + + // pick the snapshot 'B` + Snapshot snapshotB = table.currentSnapshot(); + + // `C` commit to let cherry-pick take effect, and avoid fast-forward of `B` with cherry-pick + table.newAppend().appendFile(FILE_C).commit(); + Snapshot snapshotC = table.currentSnapshot(); + + // Move the table back to `A` + table.manageSnapshots().setCurrentSnapshot(snapshotA.snapshotId()).commit(); + + // Generate A -> `D (B)` + table.manageSnapshots().cherrypick(snapshotB.snapshotId()).commit(); + Snapshot snapshotD = table.currentSnapshot(); + + // Move the table back to `C` + table.manageSnapshots().setCurrentSnapshot(snapshotC.snapshotId()).commit(); + List deletedFiles = Lists.newArrayList(); + + // Expire `C` + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .deleteWith(deletedFiles::add) + .expireOlderThan(snapshotC.timestampMillis() + 1) + .execute(); + + // Make sure no dataFiles are deleted for the B, C, D snapshot + Lists.newArrayList(snapshotB, snapshotC, snapshotD) + .forEach( + i -> { + i.addedDataFiles(table.io()) + .forEach( + item -> { + assertThat(deletedFiles).doesNotContain(item.location()); + }); + }); + + checkExpirationResults(1L, 0L, 0L, 2L, 2L, result); + } + + /** + * Test on table below, and expiring `B` which is not in current table state. 1) Expire `B` 2) All + * commit Table: A - C - D (B) ` B (staged) + */ + @TestTemplate + public void testWithExpiringStagedThenCherrypick() { + // `A` commit + table.newAppend().appendFile(FILE_A).commit(); + + // `B` commit + table.newAppend().appendFile(FILE_B).stageOnly().commit(); + + // pick the snapshot that's staged but not committed + TableMetadata base = ((BaseTable) table).operations().current(); + Snapshot snapshotB = base.snapshots().get(1); + + // `C` commit to let cherry-pick take effect, and avoid fast-forward of `B` with cherry-pick + table.newAppend().appendFile(FILE_C).commit(); + + // `D (B)` cherry-pick commit + table.manageSnapshots().cherrypick(snapshotB.snapshotId()).commit(); + + base = ((BaseTable) table).operations().current(); + Snapshot snapshotD = base.snapshots().get(3); + + List deletedFiles = Lists.newArrayList(); + + // Expire `B` commit. 
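+    // Expiring the staged snapshot is expected to drop only its manifest list and manifest;
+    // its data files survive because the cherry-picked snapshot D still references them.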
+ ExpireSnapshots.Result firstResult = + SparkActions.get() + .expireSnapshots(table) + .deleteWith(deletedFiles::add) + .expireSnapshotId(snapshotB.snapshotId()) + .execute(); + + // Make sure no dataFiles are deleted for the staged snapshot + Lists.newArrayList(snapshotB) + .forEach( + i -> { + i.addedDataFiles(table.io()) + .forEach( + item -> { + assertThat(deletedFiles).doesNotContain(item.location()); + }); + }); + checkExpirationResults(0L, 0L, 0L, 1L, 1L, firstResult); + + // Expire all snapshots including cherry-pick + ExpireSnapshots.Result secondResult = + SparkActions.get() + .expireSnapshots(table) + .deleteWith(deletedFiles::add) + .expireOlderThan(table.currentSnapshot().timestampMillis() + 1) + .execute(); + + // Make sure no dataFiles are deleted for the staged and cherry-pick + Lists.newArrayList(snapshotB, snapshotD) + .forEach( + i -> { + i.addedDataFiles(table.io()) + .forEach( + item -> { + assertThat(deletedFiles).doesNotContain(item.location()); + }); + }); + checkExpirationResults(0L, 0L, 0L, 0L, 2L, secondResult); + } + + @TestTemplate + public void testExpireOlderThan() { + table.newAppend().appendFile(FILE_A).commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + + rightAfterSnapshot(); + + table.newAppend().appendFile(FILE_B).commit(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + long tAfterCommits = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + .execute(); + + assertThat(table.currentSnapshot().snapshotId()) + .as("Expire should not change current snapshot.") + .isEqualTo(snapshotId); + assertThat(table.snapshot(firstSnapshot.snapshotId())) + .as("Expire should remove the oldest snapshot.") + .isNull(); + assertThat(deletedFiles) + .as("Should remove only the expired manifest list location.") + .isEqualTo(Sets.newHashSet(firstSnapshot.manifestListLocation())); + + checkExpirationResults(0, 0, 0, 0, 1, result); + } + + @TestTemplate + public void testExpireOlderThanWithDelete() { + table.newAppend().appendFile(FILE_A).commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + assertThat(firstSnapshot.allManifests(table.io())).as("Should create one manifest").hasSize(1); + + rightAfterSnapshot(); + + table.newDelete().deleteFile(FILE_A).commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + assertThat(secondSnapshot.allManifests(table.io())) + .as("Should create replace manifest with a rewritten manifest") + .hasSize(1); + + table.newAppend().appendFile(FILE_B).commit(); + + rightAfterSnapshot(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + long tAfterCommits = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + .execute(); + + assertThat(table.currentSnapshot().snapshotId()) + .as("Expire should not change current snapshot.") + .isEqualTo(snapshotId); + assertThat(table.snapshot(firstSnapshot.snapshotId())) + .as("Expire should remove the oldest snapshot.") + .isNull(); + assertThat(table.snapshot(secondSnapshot.snapshotId())) + .as("Expire should remove the second oldest snapshot.") + .isNull(); + assertThat(deletedFiles) + .as("Should remove expired manifest lists and deleted data file.") + .isEqualTo( + Sets.newHashSet( + firstSnapshot.manifestListLocation(), // 
snapshot expired + firstSnapshot + .allManifests(table.io()) + .get(0) + .path(), // manifest was rewritten for delete + secondSnapshot.manifestListLocation(), // snapshot expired + secondSnapshot + .allManifests(table.io()) + .get(0) + .path(), // manifest contained only deletes, was dropped + FILE_A.path()) // deleted + ); + + checkExpirationResults(1, 0, 0, 2, 2, result); + } + + @TestTemplate + public void testExpireOlderThanWithDeleteInMergedManifests() { + // merge every commit + table.updateProperties().set(TableProperties.MANIFEST_MIN_MERGE_COUNT, "0").commit(); + + table.newAppend().appendFile(FILE_A).appendFile(FILE_B).commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + assertThat(firstSnapshot.allManifests(table.io())).as("Should create one manifest").hasSize(1); + + rightAfterSnapshot(); + + table + .newDelete() + .deleteFile(FILE_A) // FILE_B is still in the dataset + .commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + assertThat(secondSnapshot.allManifests(table.io())) + .as("Should replace manifest with a rewritten manifest") + .hasSize(1); + table + .newFastAppend() // do not merge to keep the last snapshot's manifest valid + .appendFile(FILE_C) + .commit(); + + rightAfterSnapshot(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + long tAfterCommits = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + .execute(); + + assertThat(table.currentSnapshot().snapshotId()) + .as("Expire should not change current snapshot.") + .isEqualTo(snapshotId); + assertThat(table.snapshot(firstSnapshot.snapshotId())) + .as("Expire should remove the oldest snapshot.") + .isNull(); + assertThat(table.snapshot(secondSnapshot.snapshotId())) + .as("Expire should remove the second oldest snapshot.") + .isNull(); + + assertThat(deletedFiles) + .as("Should remove expired manifest lists and deleted data file.") + .isEqualTo( + Sets.newHashSet( + firstSnapshot.manifestListLocation(), // snapshot expired + firstSnapshot + .allManifests(table.io()) + .get(0) + .path(), // manifest was rewritten for delete + secondSnapshot.manifestListLocation(), // snapshot expired + FILE_A.path()) // deleted + ); + checkExpirationResults(1, 0, 0, 1, 2, result); + } + + @TestTemplate + public void testExpireOlderThanWithRollback() { + // merge every commit + table.updateProperties().set(TableProperties.MANIFEST_MIN_MERGE_COUNT, "0").commit(); + + table.newAppend().appendFile(FILE_A).appendFile(FILE_B).commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + assertThat(firstSnapshot.allManifests(table.io())).as("Should create one manifest").hasSize(1); + + rightAfterSnapshot(); + + table.newDelete().deleteFile(FILE_B).commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + Set secondSnapshotManifests = + Sets.newHashSet(secondSnapshot.allManifests(table.io())); + secondSnapshotManifests.removeAll(firstSnapshot.allManifests(table.io())); + assertThat(secondSnapshotManifests).as("Should add one new manifest for append").hasSize(1); + + table.manageSnapshots().rollbackTo(firstSnapshot.snapshotId()).commit(); + + long tAfterCommits = rightAfterSnapshot(secondSnapshot.snapshotId()); + + long snapshotId = table.currentSnapshot().snapshotId(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(tAfterCommits) + 
.deleteWith(deletedFiles::add) + .execute(); + + assertThat(table.currentSnapshot().snapshotId()) + .as("Expire should not change current snapshot.") + .isEqualTo(snapshotId); + + assertThat(table.snapshot(firstSnapshot.snapshotId())) + .as("Expire should keep the oldest snapshot, current.") + .isNotNull(); + assertThat(table.snapshot(secondSnapshot.snapshotId())) + .as("Expire should remove the orphaned snapshot.") + .isNull(); + + assertThat(deletedFiles) + .as("Should remove expired manifest lists and reverted appended data file") + .isEqualTo( + Sets.newHashSet( + secondSnapshot.manifestListLocation(), // snapshot expired + Iterables.getOnlyElement(secondSnapshotManifests) + .path()) // manifest is no longer referenced + ); + + checkExpirationResults(0, 0, 0, 1, 1, result); + } + + @TestTemplate + public void testExpireOlderThanWithRollbackAndMergedManifests() { + table.newAppend().appendFile(FILE_A).commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + assertThat(firstSnapshot.allManifests(table.io())).as("Should create one manifest").hasSize(1); + rightAfterSnapshot(); + + table.newAppend().appendFile(FILE_B).commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + Set secondSnapshotManifests = + Sets.newHashSet(secondSnapshot.allManifests(table.io())); + secondSnapshotManifests.removeAll(firstSnapshot.allManifests(table.io())); + assertThat(secondSnapshotManifests).as("Should add one new manifest for append").hasSize(1); + + table.manageSnapshots().rollbackTo(firstSnapshot.snapshotId()).commit(); + + long tAfterCommits = rightAfterSnapshot(secondSnapshot.snapshotId()); + + long snapshotId = table.currentSnapshot().snapshotId(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + .execute(); + + assertThat(table.currentSnapshot().snapshotId()) + .as("Expire should not change current snapshot.") + .isEqualTo(snapshotId); + + assertThat(table.snapshot(firstSnapshot.snapshotId())) + .as("Expire should keep the oldest snapshot, current.") + .isNotNull(); + assertThat(table.snapshot(secondSnapshot.snapshotId())) + .as("Expire should remove the orphaned snapshot.") + .isNull(); + + assertThat(deletedFiles) + .as("Should remove expired manifest lists and reverted appended data file") + .isEqualTo( + Sets.newHashSet( + secondSnapshot.manifestListLocation(), // snapshot expired + Iterables.getOnlyElement(secondSnapshotManifests) + .path(), // manifest is no longer referenced + FILE_B.path()) // added, but rolled back + ); + + checkExpirationResults(1, 0, 0, 1, 1, result); + } + + @TestTemplate + public void testExpireOlderThanWithDeleteFile() { + assumeThat(formatVersion).isGreaterThanOrEqualTo(2); + table.updateProperties().set(TableProperties.MANIFEST_MERGE_ENABLED, "false").commit(); + + // Add Data File + table.newAppend().appendFile(FILE_A).commit(); + Snapshot firstSnapshot = table.currentSnapshot(); + + // Add POS Delete + DeleteFile fileADeletes = fileADeletes(); + table.newRowDelta().addDeletes(fileADeletes).commit(); + Snapshot secondSnapshot = table.currentSnapshot(); + + // Add EQ Delete + table.newRowDelta().addDeletes(FILE_A_EQ_DELETES).commit(); + Snapshot thirdSnapshot = table.currentSnapshot(); + + // Move files to DELETED + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + Snapshot fourthSnapshot = table.currentSnapshot(); + + long afterAllDeleted = rightAfterSnapshot(); + + 
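+    // Append FILE_B after the cutoff so the table keeps a live snapshot; only files reachable
+    // solely from the four expired snapshots should be removed below.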
table.newAppend().appendFile(FILE_B).commit(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(afterAllDeleted) + .deleteWith(deletedFiles::add) + .execute(); + + Set expectedDeletes = + Sets.newHashSet( + firstSnapshot.manifestListLocation(), + secondSnapshot.manifestListLocation(), + thirdSnapshot.manifestListLocation(), + fourthSnapshot.manifestListLocation(), + FILE_A.location(), + fileADeletes.location(), + FILE_A_EQ_DELETES.location()); + + expectedDeletes.addAll( + thirdSnapshot.allManifests(table.io()).stream() + .map(ManifestFile::path) + .map(CharSequence::toString) + .collect(Collectors.toSet())); + // Delete operation (fourth snapshot) generates new manifest files + expectedDeletes.addAll( + fourthSnapshot.allManifests(table.io()).stream() + .map(ManifestFile::path) + .map(CharSequence::toString) + .collect(Collectors.toSet())); + + assertThat(deletedFiles) + .as("Should remove expired manifest lists and deleted data file") + .isEqualTo(expectedDeletes); + + checkExpirationResults(1, 1, 1, 6, 4, result); + } + + @TestTemplate + public void testExpireOnEmptyTable() { + Set deletedFiles = Sets.newHashSet(); + + // table has no data, testing ExpireSnapshots should not fail with no snapshot + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(System.currentTimeMillis()) + .deleteWith(deletedFiles::add) + .execute(); + + checkExpirationResults(0, 0, 0, 0, 0, result); + } + + @TestTemplate + public void testExpireAction() { + table.newAppend().appendFile(FILE_A).commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + + rightAfterSnapshot(); + + table.newAppend().appendFile(FILE_B).commit(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + long tAfterCommits = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshotsSparkAction action = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add); + Dataset pendingDeletes = action.expireFiles(); + + List pending = pendingDeletes.collectAsList(); + + assertThat(table.currentSnapshot().snapshotId()) + .as("Should not change current snapshot.") + .isEqualTo(snapshotId); + + assertThat(table.snapshot(firstSnapshot.snapshotId())) + .as("Should remove the oldest snapshot") + .isNull(); + assertThat(pending).as("Pending deletes should contain one row").hasSize(1); + + assertThat(pending.get(0).getPath()) + .as("Pending delete should be the expired manifest list location") + .isEqualTo(firstSnapshot.manifestListLocation()); + + assertThat(pending.get(0).getType()) + .as("Pending delete should be a manifest list") + .isEqualTo("Manifest List"); + + assertThat(deletedFiles).as("Should not delete any files").hasSize(0); + + assertThat(action.expireFiles().count()) + .as("Multiple calls to expire should return the same count of deleted files") + .isEqualTo(pendingDeletes.count()); + } + + @TestTemplate + public void testUseLocalIterator() { + table.newFastAppend().appendFile(FILE_A).commit(); + + table.newOverwrite().deleteFile(FILE_A).addFile(FILE_B).commit(); + + table.newFastAppend().appendFile(FILE_C).commit(); + + long end = rightAfterSnapshot(); + + int jobsBeforeStreamResults = spark.sparkContext().dagScheduler().nextJobId().get(); + + withSQLConf( + ImmutableMap.of("spark.sql.adaptive.enabled", "false"), + () -> { + ExpireSnapshots.Result results = + SparkActions.get() + .expireSnapshots(table) + 
.expireOlderThan(end) + .option("stream-results", "true") + .execute(); + + int jobsAfterStreamResults = spark.sparkContext().dagScheduler().nextJobId().get(); + int jobsRunDuringStreamResults = jobsAfterStreamResults - jobsBeforeStreamResults; + + checkExpirationResults(1L, 0L, 0L, 1L, 2L, results); + + assertThat(jobsRunDuringStreamResults) + .as( + "Expected total number of jobs with stream-results should match the expected number") + .isEqualTo(4L); + }); + } + + @TestTemplate + public void testExpireAfterExecute() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + + rightAfterSnapshot(); + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + long t3 = rightAfterSnapshot(); + + ExpireSnapshotsSparkAction action = SparkActions.get().expireSnapshots(table); + + action.expireOlderThan(t3).retainLast(2); + + ExpireSnapshots.Result result = action.execute(); + checkExpirationResults(0L, 0L, 0L, 0L, 1L, result); + + List typedExpiredFiles = action.expireFiles().collectAsList(); + assertThat(typedExpiredFiles).as("Expired results must match").hasSize(1); + + List untypedExpiredFiles = action.expireFiles().collectAsList(); + assertThat(untypedExpiredFiles).as("Expired results must match").hasSize(1); + } + + @TestTemplate + public void testExpireFileDeletionMostExpired() { + textExpireAllCheckFilesDeleted(5, 2); + } + + @TestTemplate + public void testExpireFileDeletionMostRetained() { + textExpireAllCheckFilesDeleted(2, 5); + } + + public void textExpireAllCheckFilesDeleted(int dataFilesExpired, int dataFilesRetained) { + // Add data files to be expired + Set dataFiles = Sets.newHashSet(); + for (int i = 0; i < dataFilesExpired; i++) { + DataFile df = + DataFiles.builder(SPEC) + .withPath(String.format("/path/to/data-expired-%d.parquet", i)) + .withFileSizeInBytes(10) + .withPartitionPath("c1=1") + .withRecordCount(1) + .build(); + dataFiles.add(df.location()); + table.newFastAppend().appendFile(df).commit(); + } + + // Delete them all, these will be deleted on expire snapshot + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + // Clears "DELETED" manifests + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + + Set manifestsBefore = TestHelpers.reachableManifestPaths(table); + + // Add data files to be retained, which are not deleted. 
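+    // (these remain reachable from the current snapshot, so they must not show up in deletedFiles)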
+ for (int i = 0; i < dataFilesRetained; i++) { + DataFile df = + DataFiles.builder(SPEC) + .withPath(String.format("/path/to/data-retained-%d.parquet", i)) + .withFileSizeInBytes(10) + .withPartitionPath("c1=1") + .withRecordCount(1) + .build(); + table.newFastAppend().appendFile(df).commit(); + } + + long end = rightAfterSnapshot(); + + Set expectedDeletes = Sets.newHashSet(); + expectedDeletes.addAll(ReachableFileUtil.manifestListLocations(table)); + // all snapshot manifest lists except current will be deleted + expectedDeletes.remove(table.currentSnapshot().manifestListLocation()); + expectedDeletes.addAll( + manifestsBefore); // new manifests are reachable from current snapshot and not deleted + expectedDeletes.addAll( + dataFiles); // new data files are reachable from current snapshot and not deleted + + Set deletedFiles = Sets.newHashSet(); + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(end) + .deleteWith(deletedFiles::add) + .execute(); + + assertThat(deletedFiles) + .as("All reachable files before expiration should be deleted") + .isEqualTo(expectedDeletes); + } + + @TestTemplate + public void testExpireSomeCheckFilesDeleted() { + + table.newAppend().appendFile(FILE_A).commit(); + + table.newAppend().appendFile(FILE_B).commit(); + + table.newAppend().appendFile(FILE_C).commit(); + + table.newDelete().deleteFile(FILE_A).commit(); + + long after = rightAfterSnapshot(); + waitUntilAfter(after); + + table.newAppend().appendFile(FILE_D).commit(); + + table.newDelete().deleteFile(FILE_B).commit(); + + Set deletedFiles = Sets.newHashSet(); + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(after) + .deleteWith(deletedFiles::add) + .execute(); + + // C, D should be retained (live) + // B should be retained (previous snapshot points to it) + // A should be deleted + assertThat(deletedFiles).contains(FILE_A.location()); + assertThat(deletedFiles).doesNotContain(FILE_B.location()); + assertThat(deletedFiles).doesNotContain(FILE_C.location()); + assertThat(deletedFiles).doesNotContain(FILE_D.location()); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestMigrateTableAction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestMigrateTableAction.java new file mode 100644 index 000000000000..94afa50cf4b8 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestMigrateTableAction.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.IOException; +import java.nio.file.Files; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.spark.CatalogTestBase; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestMigrateTableAction extends CatalogTestBase { + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s_BACKUP_", tableName); + } + + @TestTemplate + public void testMigrateWithParallelTasks() throws IOException { + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + AtomicInteger migrationThreadsIndex = new AtomicInteger(0); + SparkActions.get() + .migrateTable(tableName) + .executeWith( + Executors.newFixedThreadPool( + 4, + runnable -> { + Thread thread = new Thread(runnable); + thread.setName("table-migration-" + migrationThreadsIndex.getAndIncrement()); + thread.setDaemon(true); + return thread; + })) + .execute(); + assertThat(migrationThreadsIndex.get()).isEqualTo(2); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveDanglingDeleteAction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveDanglingDeleteAction.java new file mode 100644 index 000000000000..e58966cfea3f --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveDanglingDeleteAction.java @@ -0,0 +1,491 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.util.Arrays; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileGenerationUtil; +import org.apache.iceberg.FileMetadata; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.RemoveDanglingDeleteFiles; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Encoders; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; +import scala.Tuple2; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestRemoveDanglingDeleteAction extends TestBase { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = + new Schema( + optional(1, "c1", Types.StringType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + + private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA).identity("c1").build(); + + static final DataFile FILE_A = + DataFiles.builder(SPEC) + .withPath("/path/to/data-a.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=a") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_A2 = + DataFiles.builder(SPEC) + .withPath("/path/to/data-a.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=a") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_B = + DataFiles.builder(SPEC) + .withPath("/path/to/data-b.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=b") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_B2 = + DataFiles.builder(SPEC) + .withPath("/path/to/data-b.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=b") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_C = + DataFiles.builder(SPEC) + .withPath("/path/to/data-c.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=c") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_C2 = + DataFiles.builder(SPEC) + .withPath("/path/to/data-c.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=c") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_D = + DataFiles.builder(SPEC) + .withPath("/path/to/data-d.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=d") // easy 
way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_D2 = + DataFiles.builder(SPEC) + .withPath("/path/to/data-d.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=d") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A_POS_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofPositionDeletes() + .withPath("/path/to/data-a-pos-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=a") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A2_POS_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofPositionDeletes() + .withPath("/path/to/data-a2-pos-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=a") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A_EQ_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofEqualityDeletes() + .withPath("/path/to/data-a-eq-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=a") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A2_EQ_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofEqualityDeletes() + .withPath("/path/to/data-a2-eq-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=a") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_B_POS_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofPositionDeletes() + .withPath("/path/to/data-b-pos-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=b") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_B2_POS_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofPositionDeletes() + .withPath("/path/to/data-b2-pos-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=b") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_B_EQ_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofEqualityDeletes() + .withPath("/path/to/data-b-eq-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=b") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_B2_EQ_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofEqualityDeletes() + .withPath("/path/to/data-b2-eq-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=b") // easy way to set partition data for now + .withRecordCount(1) + .build(); + + static final DataFile FILE_UNPARTITIONED = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withPath("/path/to/data-unpartitioned.parquet") + .withFileSizeInBytes(10) + .withRecordCount(1) + .build(); + static final DeleteFile FILE_UNPARTITIONED_POS_DELETE = + FileMetadata.deleteFileBuilder(PartitionSpec.unpartitioned()) + .ofEqualityDeletes() + .withPath("/path/to/data-unpartitioned-pos-deletes.parquet") + .withFileSizeInBytes(10) + .withRecordCount(1) + .build(); + static final DeleteFile FILE_UNPARTITIONED_EQ_DELETE = + FileMetadata.deleteFileBuilder(PartitionSpec.unpartitioned()) + .ofEqualityDeletes() + .withPath("/path/to/data-unpartitioned-eq-deletes.parquet") + .withFileSizeInBytes(10) + .withRecordCount(1) + .build(); + + @TempDir private File tableDir; + @Parameter private int formatVersion; + + @Parameters(name = "formatVersion = {0}") + protected static List parameters() { + return 
Arrays.asList(2, 3); + } + + private String tableLocation = null; + private Table table; + + @BeforeEach + public void before() throws Exception { + this.tableLocation = tableDir.toURI().toString(); + } + + @AfterEach + public void after() { + TABLES.dropTable(tableLocation); + } + + private void setupPartitionedTable() { + this.table = + TABLES.create( + SCHEMA, + SPEC, + ImmutableMap.of(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)), + tableLocation); + } + + private void setupUnpartitionedTable() { + this.table = + TABLES.create( + SCHEMA, + PartitionSpec.unpartitioned(), + ImmutableMap.of(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)), + tableLocation); + } + + private DeleteFile fileADeletes() { + return formatVersion >= 3 ? FileGenerationUtil.generateDV(table, FILE_A) : FILE_A_POS_DELETES; + } + + private DeleteFile fileA2Deletes() { + return formatVersion >= 3 ? FileGenerationUtil.generateDV(table, FILE_A2) : FILE_A2_POS_DELETES; + } + + private DeleteFile fileBDeletes() { + return formatVersion >= 3 ? FileGenerationUtil.generateDV(table, FILE_B) : FILE_B_POS_DELETES; + } + + private DeleteFile fileB2Deletes() { + return formatVersion >= 3 ? FileGenerationUtil.generateDV(table, FILE_B2) : FILE_B2_POS_DELETES; + } + + private DeleteFile fileUnpartitionedDeletes() { + return formatVersion >= 3 + ? FileGenerationUtil.generateDV(table, FILE_UNPARTITIONED) + : FILE_UNPARTITIONED_POS_DELETE; + } + + @TestTemplate + public void testPartitionedDeletesWithLesserSeqNo() { + setupPartitionedTable(); + + // Add Data Files + table.newAppend().appendFile(FILE_B).appendFile(FILE_C).appendFile(FILE_D).commit(); + + // Add Delete Files + DeleteFile fileADeletes = fileADeletes(); + DeleteFile fileA2Deletes = fileA2Deletes(); + DeleteFile fileBDeletes = fileBDeletes(); + DeleteFile fileB2Deletes = fileB2Deletes(); + table + .newRowDelta() + .addDeletes(fileADeletes) + .addDeletes(fileA2Deletes) + .addDeletes(fileBDeletes) + .addDeletes(fileB2Deletes) + .addDeletes(FILE_A_EQ_DELETES) + .addDeletes(FILE_A2_EQ_DELETES) + .addDeletes(FILE_B_EQ_DELETES) + .addDeletes(FILE_B2_EQ_DELETES) + .commit(); + + // Add More Data Files + table + .newAppend() + .appendFile(FILE_A2) + .appendFile(FILE_B2) + .appendFile(FILE_C2) + .appendFile(FILE_D2) + .commit(); + + List> actual = + spark + .read() + .format("iceberg") + .load(tableLocation + "#entries") + .select("sequence_number", "data_file.file_path") + .sort("sequence_number", "data_file.file_path") + .as(Encoders.tuple(Encoders.LONG(), Encoders.STRING())) + .collectAsList(); + List> expected = + ImmutableList.of( + Tuple2.apply(1L, FILE_B.path().toString()), + Tuple2.apply(1L, FILE_C.path().toString()), + Tuple2.apply(1L, FILE_D.path().toString()), + Tuple2.apply(2L, FILE_A_EQ_DELETES.path().toString()), + Tuple2.apply(2L, fileADeletes.path().toString()), + Tuple2.apply(2L, FILE_A2_EQ_DELETES.path().toString()), + Tuple2.apply(2L, fileA2Deletes.path().toString()), + Tuple2.apply(2L, FILE_B_EQ_DELETES.path().toString()), + Tuple2.apply(2L, fileBDeletes.path().toString()), + Tuple2.apply(2L, FILE_B2_EQ_DELETES.path().toString()), + Tuple2.apply(2L, fileB2Deletes.path().toString()), + Tuple2.apply(3L, FILE_A2.path().toString()), + Tuple2.apply(3L, FILE_B2.path().toString()), + Tuple2.apply(3L, FILE_C2.path().toString()), + Tuple2.apply(3L, FILE_D2.path().toString())); + assertThat(actual).containsExactlyInAnyOrderElementsOf(expected); + + RemoveDanglingDeleteFiles.Result result = + 
SparkActions.get().removeDanglingDeleteFiles(table).execute(); + + // All Delete files of the FILE A partition should be removed + // because there are no data files in partition with a lesser sequence number + + Set removedDeleteFiles = + StreamSupport.stream(result.removedDeleteFiles().spliterator(), false) + .map(DeleteFile::path) + .collect(Collectors.toSet()); + assertThat(removedDeleteFiles) + .as("Expected 4 delete files removed") + .hasSize(4) + .containsExactlyInAnyOrder( + fileADeletes.path(), + fileA2Deletes.path(), + FILE_A_EQ_DELETES.path(), + FILE_A2_EQ_DELETES.path()); + + List> actualAfter = + spark + .read() + .format("iceberg") + .load(tableLocation + "#entries") + .filter("status < 2") // live files + .select("sequence_number", "data_file.file_path") + .sort("sequence_number", "data_file.file_path") + .as(Encoders.tuple(Encoders.LONG(), Encoders.STRING())) + .collectAsList(); + List> expectedAfter = + ImmutableList.of( + Tuple2.apply(1L, FILE_B.path().toString()), + Tuple2.apply(1L, FILE_C.path().toString()), + Tuple2.apply(1L, FILE_D.path().toString()), + Tuple2.apply(2L, FILE_B_EQ_DELETES.path().toString()), + Tuple2.apply(2L, fileBDeletes.path().toString()), + Tuple2.apply(2L, FILE_B2_EQ_DELETES.path().toString()), + Tuple2.apply(2L, fileB2Deletes.path().toString()), + Tuple2.apply(3L, FILE_A2.path().toString()), + Tuple2.apply(3L, FILE_B2.path().toString()), + Tuple2.apply(3L, FILE_C2.path().toString()), + Tuple2.apply(3L, FILE_D2.path().toString())); + assertThat(actualAfter).containsExactlyInAnyOrderElementsOf(expectedAfter); + } + + @TestTemplate + public void testPartitionedDeletesWithEqSeqNo() { + setupPartitionedTable(); + + // Add Data Files + table.newAppend().appendFile(FILE_A).appendFile(FILE_C).appendFile(FILE_D).commit(); + + // Add Data Files with EQ and POS deletes + DeleteFile fileADeletes = fileADeletes(); + DeleteFile fileA2Deletes = fileA2Deletes(); + DeleteFile fileBDeletes = fileBDeletes(); + DeleteFile fileB2Deletes = fileB2Deletes(); + table + .newRowDelta() + .addRows(FILE_A2) + .addRows(FILE_B2) + .addRows(FILE_C2) + .addRows(FILE_D2) + .addDeletes(fileADeletes) + .addDeletes(fileA2Deletes) + .addDeletes(FILE_A_EQ_DELETES) + .addDeletes(FILE_A2_EQ_DELETES) + .addDeletes(fileBDeletes) + .addDeletes(fileB2Deletes) + .addDeletes(FILE_B_EQ_DELETES) + .addDeletes(FILE_B2_EQ_DELETES) + .commit(); + + List> actual = + spark + .read() + .format("iceberg") + .load(tableLocation + "#entries") + .select("sequence_number", "data_file.file_path") + .sort("sequence_number", "data_file.file_path") + .as(Encoders.tuple(Encoders.LONG(), Encoders.STRING())) + .collectAsList(); + List> expected = + ImmutableList.of( + Tuple2.apply(1L, FILE_A.path().toString()), + Tuple2.apply(1L, FILE_C.path().toString()), + Tuple2.apply(1L, FILE_D.path().toString()), + Tuple2.apply(2L, FILE_A_EQ_DELETES.path().toString()), + Tuple2.apply(2L, fileADeletes.path().toString()), + Tuple2.apply(2L, FILE_A2.path().toString()), + Tuple2.apply(2L, FILE_A2_EQ_DELETES.path().toString()), + Tuple2.apply(2L, fileA2Deletes.path().toString()), + Tuple2.apply(2L, FILE_B_EQ_DELETES.path().toString()), + Tuple2.apply(2L, fileBDeletes.path().toString()), + Tuple2.apply(2L, FILE_B2.path().toString()), + Tuple2.apply(2L, FILE_B2_EQ_DELETES.path().toString()), + Tuple2.apply(2L, fileB2Deletes.path().toString()), + Tuple2.apply(2L, FILE_C2.path().toString()), + Tuple2.apply(2L, FILE_D2.path().toString())); + assertThat(actual).containsExactlyInAnyOrderElementsOf(expected); + + 
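+ // Position deletes and DVs can apply to data files with an equal or lower data sequence number, + // while equality deletes apply only to strictly lower ones, so with every file committed at the + // same sequence number only partition B's equality deletes (B has no data file at a lower + // sequence number) are expected to be dangling.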
RemoveDanglingDeleteFiles.Result result = + SparkActions.get().removeDanglingDeleteFiles(table).execute(); + + // Eq Delete files of the FILE B partition should be removed + // because there are no data files in partition with a lesser sequence number + Set removedDeleteFiles = + StreamSupport.stream(result.removedDeleteFiles().spliterator(), false) + .map(DeleteFile::path) + .collect(Collectors.toSet()); + assertThat(removedDeleteFiles) + .as("Expected two delete files removed") + .hasSize(2) + .containsExactlyInAnyOrder(FILE_B_EQ_DELETES.path(), FILE_B2_EQ_DELETES.path()); + + List> actualAfter = + spark + .read() + .format("iceberg") + .load(tableLocation + "#entries") + .filter("status < 2") // live files + .select("sequence_number", "data_file.file_path") + .sort("sequence_number", "data_file.file_path") + .as(Encoders.tuple(Encoders.LONG(), Encoders.STRING())) + .collectAsList(); + List> expectedAfter = + ImmutableList.of( + Tuple2.apply(1L, FILE_A.path().toString()), + Tuple2.apply(1L, FILE_C.path().toString()), + Tuple2.apply(1L, FILE_D.path().toString()), + Tuple2.apply(2L, FILE_A_EQ_DELETES.path().toString()), + Tuple2.apply(2L, fileADeletes.path().toString()), + Tuple2.apply(2L, FILE_A2.path().toString()), + Tuple2.apply(2L, FILE_A2_EQ_DELETES.path().toString()), + Tuple2.apply(2L, fileA2Deletes.path().toString()), + Tuple2.apply(2L, fileBDeletes.path().toString()), + Tuple2.apply(2L, FILE_B2.path().toString()), + Tuple2.apply(2L, fileB2Deletes.path().toString()), + Tuple2.apply(2L, FILE_C2.path().toString()), + Tuple2.apply(2L, FILE_D2.path().toString())); + assertThat(actualAfter).containsExactlyInAnyOrderElementsOf(expectedAfter); + } + + @TestTemplate + public void testUnpartitionedTable() { + setupUnpartitionedTable(); + + table + .newRowDelta() + .addDeletes(fileUnpartitionedDeletes()) + .addDeletes(FILE_UNPARTITIONED_EQ_DELETE) + .commit(); + table.newAppend().appendFile(FILE_UNPARTITIONED).commit(); + + RemoveDanglingDeleteFiles.Result result = + SparkActions.get().removeDanglingDeleteFiles(table).execute(); + assertThat(result.removedDeleteFiles()).as("No-op for unpartitioned tables").isEmpty(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java new file mode 100644 index 000000000000..d36898d4c464 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java @@ -0,0 +1,1100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.Files; +import org.apache.iceberg.GenericBlobMetadata; +import org.apache.iceberg.GenericStatisticsFile; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.Transaction; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopCatalog; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.hadoop.HiddenPathFilter; +import org.apache.iceberg.puffin.Blob; +import org.apache.iceberg.puffin.Puffin; +import org.apache.iceberg.puffin.PuffinWriter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.spark.actions.DeleteOrphanFilesSparkAction.StringToFileURI; +import org.apache.iceberg.spark.source.FilePathLastModifiedRecord; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public abstract class TestRemoveOrphanFilesAction extends TestBase { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + protected static final Schema SCHEMA = + new Schema( + 
optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + protected static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA).truncate("c2", 2).identity("c3").build(); + + @TempDir private File tableDir = null; + protected String tableLocation = null; + protected Map properties; + @Parameter private int formatVersion; + + @Parameters(name = "formatVersion = {0}") + protected static List parameters() { + return Arrays.asList(2, 3); + } + + @BeforeEach + public void setupTableLocation() throws Exception { + this.tableLocation = tableDir.toURI().toString(); + properties = ImmutableMap.of(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + } + + @TestTemplate + public void testDryRun() throws IOException { + Table table = TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), properties, tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + List validFiles = + spark + .read() + .format("iceberg") + .load(tableLocation + "#files") + .select("file_path") + .as(Encoders.STRING()) + .collectAsList(); + assertThat(validFiles).as("Should be 2 valid files").hasSize(2); + + df.write().mode("append").parquet(tableLocation + "/data"); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + List allFiles = + Arrays.stream(fs.listStatus(dataPath, HiddenPathFilter.get())) + .filter(FileStatus::isFile) + .map(file -> file.getPath().toString()) + .collect(Collectors.toList()); + assertThat(allFiles).as("Should be 3 valid files").hasSize(3); + + List invalidFiles = Lists.newArrayList(allFiles); + invalidFiles.removeAll(validFiles); + assertThat(invalidFiles).as("Should be 1 invalid file").hasSize(1); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result1 = + actions.deleteOrphanFiles(table).deleteWith(s -> {}).execute(); + assertThat(result1.orphanFileLocations()) + .as("Default olderThan interval should be safe") + .isEmpty(); + + DeleteOrphanFiles.Result result2 = + actions + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .deleteWith(s -> {}) + .execute(); + assertThat(result2.orphanFileLocations()) + .as("Action should find 1 file") + .isEqualTo(invalidFiles); + assertThat(fs.exists(new Path(invalidFiles.get(0)))) + .as("Invalid file should be present") + .isTrue(); + + DeleteOrphanFiles.Result result3 = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + assertThat(result3.orphanFileLocations()) + .as("Action should delete 1 file") + .isEqualTo(invalidFiles); + assertThat(fs.exists(new Path(invalidFiles.get(0)))) + .as("Invalid file should not be present") + .isFalse(); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records); + expectedRecords.addAll(records); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + assertThat(actualRecords).isEqualTo(expectedRecords); + } + + @TestTemplate + public void testAllValidFilesAreKept() 
throws IOException { + Table table = TABLES.create(SCHEMA, SPEC, properties, tableLocation); + + List records1 = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df1 = spark.createDataFrame(records1, ThreeColumnRecord.class).coalesce(1); + + // original append + df1.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + List records2 = + Lists.newArrayList(new ThreeColumnRecord(2, "AAAAAAAAAA", "AAAA")); + Dataset df2 = spark.createDataFrame(records2, ThreeColumnRecord.class).coalesce(1); + + // dynamic partition overwrite + df2.select("c1", "c2", "c3").write().format("iceberg").mode("overwrite").save(tableLocation); + + // second append + df2.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + List snapshots = Lists.newArrayList(table.snapshots()); + + List snapshotFiles1 = snapshotFiles(snapshots.get(0).snapshotId()); + assertThat(snapshotFiles1).hasSize(1); + + List snapshotFiles2 = snapshotFiles(snapshots.get(1).snapshotId()); + assertThat(snapshotFiles2).hasSize(1); + + List snapshotFiles3 = snapshotFiles(snapshots.get(2).snapshotId()); + assertThat(snapshotFiles3).hasSize(2); + + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data"); + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA"); + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA"); + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/invalid/invalid"); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + assertThat(result.orphanFileLocations()).as("Should delete 4 files").hasSize(4); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + + for (String fileLocation : snapshotFiles1) { + assertThat(fs.exists(new Path(fileLocation))).as("All snapshot files must remain").isTrue(); + } + + for (String fileLocation : snapshotFiles2) { + assertThat(fs.exists(new Path(fileLocation))).as("All snapshot files must remain").isTrue(); + } + + for (String fileLocation : snapshotFiles3) { + assertThat(fs.exists(new Path(fileLocation))).as("All snapshot files must remain").isTrue(); + } + } + + @TestTemplate + public void orphanedFileRemovedWithParallelTasks() { + Table table = TABLES.create(SCHEMA, SPEC, properties, tableLocation); + + List records1 = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df1 = spark.createDataFrame(records1, ThreeColumnRecord.class).coalesce(1); + + // original append + df1.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + List records2 = + Lists.newArrayList(new ThreeColumnRecord(2, "AAAAAAAAAA", "AAAA")); + Dataset df2 = spark.createDataFrame(records2, ThreeColumnRecord.class).coalesce(1); + + // dynamic partition overwrite + df2.select("c1", "c2", "c3").write().format("iceberg").mode("overwrite").save(tableLocation); + + // second append + df2.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data"); + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA"); + df2.coalesce(1).write().mode("append").parquet(tableLocation + 
"/data/c2_trunc=AA/c3=AAAA"); + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/invalid/invalid"); + + waitUntilAfter(System.currentTimeMillis()); + + Set deletedFiles = ConcurrentHashMap.newKeySet(); + Set deleteThreads = ConcurrentHashMap.newKeySet(); + AtomicInteger deleteThreadsIndex = new AtomicInteger(0); + + ExecutorService executorService = + Executors.newFixedThreadPool( + 4, + runnable -> { + Thread thread = new Thread(runnable); + thread.setName("remove-orphan-" + deleteThreadsIndex.getAndIncrement()); + thread.setDaemon(true); + return thread; + }); + + DeleteOrphanFiles.Result result = + SparkActions.get() + .deleteOrphanFiles(table) + .executeDeleteWith(executorService) + .olderThan(System.currentTimeMillis() + 5000) // Ensure all orphan files are selected + .deleteWith( + file -> { + deleteThreads.add(Thread.currentThread().getName()); + deletedFiles.add(file); + }) + .execute(); + + // Verifies that the delete methods ran in the threads created by the provided ExecutorService + // ThreadFactory + assertThat(deleteThreads) + .isEqualTo( + Sets.newHashSet( + "remove-orphan-0", "remove-orphan-1", "remove-orphan-2", "remove-orphan-3")); + + assertThat(deletedFiles).hasSize(4); + } + + @TestTemplate + public void testWapFilesAreKept() { + assumeThat(formatVersion).as("currently fails with DVs").isEqualTo(2); + Map props = Maps.newHashMap(); + props.put(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED, "true"); + props.putAll(properties); + Table table = TABLES.create(SCHEMA, SPEC, props, tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + + // normal write + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + spark.conf().set(SparkSQLProperties.WAP_ID, "1"); + + // wap write + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + + // TODO: currently fails because DVs delete stuff from WAP branch + assertThat(actualRecords) + .as("Should not return data from the staged snapshot") + .isEqualTo(records); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + assertThat(result.orphanFileLocations()).as("Should not delete any files").isEmpty(); + } + + @TestTemplate + public void testMetadataFolderIsIntact() { + // write data directly to the table location + Map props = Maps.newHashMap(); + props.put(TableProperties.WRITE_DATA_LOCATION, tableLocation); + props.putAll(properties); + Table table = TABLES.create(SCHEMA, SPEC, props, tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.write().mode("append").parquet(tableLocation + "/c2_trunc=AA/c3=AAAA"); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + 
assertThat(result.orphanFileLocations()).as("Should delete 1 file").hasSize(1); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + assertThat(actualRecords).as("Rows must match").isEqualTo(records); + } + + @TestTemplate + public void testOlderThanTimestamp() { + Table table = TABLES.create(SCHEMA, SPEC, properties, tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA"); + df.write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA"); + + waitUntilAfter(System.currentTimeMillis()); + + long timestamp = System.currentTimeMillis(); + + waitUntilAfter(System.currentTimeMillis() + 1000L); + + df.write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA"); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(timestamp).execute(); + + assertThat(result.orphanFileLocations()).as("Should delete only 2 files").hasSize(2); + } + + @TestTemplate + public void testRemoveUnreachableMetadataVersionFiles() { + Map props = Maps.newHashMap(); + props.put(TableProperties.WRITE_DATA_LOCATION, tableLocation); + props.put(TableProperties.METADATA_PREVIOUS_VERSIONS_MAX, "1"); + props.putAll(properties); + Table table = TABLES.create(SCHEMA, SPEC, props, tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + assertThat(result.orphanFileLocations()).as("Should delete 1 file").hasSize(1); + assertThat(StreamSupport.stream(result.orphanFileLocations().spliterator(), false)) + .as("Should remove v1 file") + .anyMatch(file -> file.contains("v1.metadata.json")); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records); + expectedRecords.addAll(records); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + assertThat(actualRecords).as("Rows must match").isEqualTo(expectedRecords); + } + + @TestTemplate + public void testManyTopLevelPartitions() { + Table table = TABLES.create(SCHEMA, SPEC, properties, tableLocation); + + List records = Lists.newArrayList(); + for (int i = 0; i < 100; i++) { + records.add(new ThreeColumnRecord(i, String.valueOf(i), String.valueOf(i))); + } + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + 
actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + assertThat(result.orphanFileLocations()).as("Should not delete any files").isEmpty(); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + assertThat(resultDF.count()).as("Rows count must match").isEqualTo(records.size()); + } + + @TestTemplate + public void testManyLeafPartitions() { + Table table = TABLES.create(SCHEMA, SPEC, properties, tableLocation); + + List records = Lists.newArrayList(); + for (int i = 0; i < 100; i++) { + records.add(new ThreeColumnRecord(i, String.valueOf(i % 3), String.valueOf(i))); + } + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + assertThat(result.orphanFileLocations()).as("Should not delete any files").isEmpty(); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + assertThat(resultDF.count()).as("Row count must match").isEqualTo(records.size()); + } + + @TestTemplate + public void testHiddenPartitionPaths() { + Schema schema = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "_c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + PartitionSpec spec = PartitionSpec.builderFor(schema).truncate("_c2", 2).identity("c3").build(); + Table table = TABLES.create(schema, spec, properties, tableLocation); + + StructType structType = + new StructType() + .add("c1", DataTypes.IntegerType) + .add("_c2", DataTypes.StringType) + .add("c3", DataTypes.StringType); + List records = Lists.newArrayList(RowFactory.create(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, structType).coalesce(1); + + df.select("c1", "_c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.write().mode("append").parquet(tableLocation + "/data/_c2_trunc=AA/c3=AAAA"); + df.write().mode("append").parquet(tableLocation + "/data/_c2_trunc=AA/c3=AAAA"); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + assertThat(result.orphanFileLocations()).as("Should delete 2 files").hasSize(2); + } + + @TestTemplate + public void testHiddenPartitionPathsWithPartitionEvolution() { + Schema schema = + new Schema( + optional(1, "_c1", Types.IntegerType.get()), + optional(2, "_c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + PartitionSpec spec = PartitionSpec.builderFor(schema).truncate("_c2", 2).build(); + Table table = TABLES.create(schema, spec, properties, tableLocation); + + StructType structType = + new StructType() + .add("_c1", DataTypes.IntegerType) + .add("_c2", DataTypes.StringType) + .add("c3", DataTypes.StringType); + List records = Lists.newArrayList(RowFactory.create(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, structType).coalesce(1); + + df.select("_c1", "_c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.write().mode("append").parquet(tableLocation + "/data/_c2_trunc=AA"); + + table.updateSpec().addField("_c1").commit(); + + df.write().mode("append").parquet(tableLocation + 
"/data/_c2_trunc=AA/_c1=1"); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + assertThat(result.orphanFileLocations()).as("Should delete 2 files").hasSize(2); + } + + @TestTemplate + public void testHiddenPathsStartingWithPartitionNamesAreIgnored() throws IOException { + Schema schema = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "_c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + PartitionSpec spec = PartitionSpec.builderFor(schema).truncate("_c2", 2).identity("c3").build(); + Table table = TABLES.create(schema, spec, properties, tableLocation); + + StructType structType = + new StructType() + .add("c1", DataTypes.IntegerType) + .add("_c2", DataTypes.StringType) + .add("c3", DataTypes.StringType); + List records = Lists.newArrayList(RowFactory.create(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, structType).coalesce(1); + + df.select("c1", "_c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + Path pathToFileInHiddenFolder = new Path(dataPath, "_c2_trunc/file.txt"); + fs.createNewFile(pathToFileInHiddenFolder); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + assertThat(result.orphanFileLocations()).as("Should delete 0 files").isEmpty(); + assertThat(fs.exists(pathToFileInHiddenFolder)).isTrue(); + } + + private List snapshotFiles(long snapshotId) { + return spark + .read() + .format("iceberg") + .option("snapshot-id", snapshotId) + .load(tableLocation + "#files") + .select("file_path") + .as(Encoders.STRING()) + .collectAsList(); + } + + @TestTemplate + public void testRemoveOrphanFilesWithRelativeFilePath() throws IOException { + Table table = + TABLES.create( + SCHEMA, PartitionSpec.unpartitioned(), properties, tableDir.getAbsolutePath()); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableDir.getAbsolutePath()); + + List validFiles = + spark + .read() + .format("iceberg") + .load(tableLocation + "#files") + .select("file_path") + .as(Encoders.STRING()) + .collectAsList(); + assertThat(validFiles).as("Should be 1 valid file").hasSize(1); + String validFile = validFiles.get(0); + + df.write().mode("append").parquet(tableLocation + "/data"); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + List allFiles = + Arrays.stream(fs.listStatus(dataPath, HiddenPathFilter.get())) + .filter(FileStatus::isFile) + .map(file -> file.getPath().toString()) + .collect(Collectors.toList()); + assertThat(allFiles).as("Should be 2 files").hasSize(2); + + List invalidFiles = Lists.newArrayList(allFiles); + invalidFiles.removeIf(file -> file.contains(validFile)); + assertThat(invalidFiles).as("Should be 1 invalid file").hasSize(1); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + 
DeleteOrphanFiles.Result result = + actions + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .deleteWith(s -> {}) + .execute(); + assertThat(result.orphanFileLocations()) + .as("Action should find 1 file") + .isEqualTo(invalidFiles); + assertThat(fs.exists(new Path(invalidFiles.get(0)))) + .as("Invalid file should be present") + .isTrue(); + } + + @TestTemplate + public void testRemoveOrphanFilesWithHadoopCatalog() throws InterruptedException { + HadoopCatalog catalog = new HadoopCatalog(new Configuration(), tableLocation); + String namespaceName = "testDb"; + String tableName = "testTb"; + + Namespace namespace = Namespace.of(namespaceName); + TableIdentifier tableIdentifier = TableIdentifier.of(namespace, tableName); + Table table = + catalog.createTable( + tableIdentifier, SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap()); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(table.location()); + + df.write().mode("append").parquet(table.location() + "/data"); + + waitUntilAfter(System.currentTimeMillis()); + + table.refresh(); + + DeleteOrphanFiles.Result result = + SparkActions.get().deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + assertThat(result.orphanFileLocations()).as("Should delete only 1 file").hasSize(1); + + Dataset resultDF = spark.read().format("iceberg").load(table.location()); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + assertThat(actualRecords).as("Rows must match").isEqualTo(records); + } + + @TestTemplate + public void testHiveCatalogTable() throws IOException { + TableIdentifier identifier = + TableIdentifier.of("default", "hivetestorphan" + ThreadLocalRandom.current().nextInt(1000)); + Table table = catalog.createTable(identifier, SCHEMA, SPEC, tableLocation, properties); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(identifier.toString()); + + String location = table.location().replaceFirst("file:", ""); + new File(location + "/data/trashfile").createNewFile(); + + DeleteOrphanFiles.Result result = + SparkActions.get() + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + assertThat(StreamSupport.stream(result.orphanFileLocations().spliterator(), false)) + .as("trash file should be removed") + .anyMatch(file -> file.contains("file:" + location + "/data/trashfile")); + } + + @TestTemplate + public void testGarbageCollectionDisabled() { + Table table = TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), properties, tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + table.updateProperties().set(TableProperties.GC_ENABLED, "false").commit(); + + assertThatThrownBy(() -> SparkActions.get().deleteOrphanFiles(table).execute()) + .isInstanceOf(ValidationException.class) + .hasMessage( + "Cannot delete orphan files: GC is disabled (deleting files may corrupt other tables)"); + } + + 
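+ // The tests below exercise compareToFileList, which skips listing the table location and instead + // compares table metadata against a caller-provided Dataset with file_path and last_modified + // columns; entries outside the table location are ignored.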
@TestTemplate + public void testCompareToFileList() throws IOException { + Table table = TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), properties, tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + List validFiles = + Arrays.stream(fs.listStatus(dataPath, HiddenPathFilter.get())) + .filter(FileStatus::isFile) + .map( + file -> + new FilePathLastModifiedRecord( + file.getPath().toString(), new Timestamp(file.getModificationTime()))) + .collect(Collectors.toList()); + + assertThat(validFiles).as("Should be 2 valid files").hasSize(2); + + df.write().mode("append").parquet(tableLocation + "/data"); + + List allFiles = + Arrays.stream(fs.listStatus(dataPath, HiddenPathFilter.get())) + .filter(FileStatus::isFile) + .map( + file -> + new FilePathLastModifiedRecord( + file.getPath().toString(), new Timestamp(file.getModificationTime()))) + .collect(Collectors.toList()); + + assertThat(allFiles).as("Should be 3 files").hasSize(3); + + List invalidFiles = Lists.newArrayList(allFiles); + invalidFiles.removeAll(validFiles); + List invalidFilePaths = + invalidFiles.stream() + .map(FilePathLastModifiedRecord::getFilePath) + .collect(Collectors.toList()); + assertThat(invalidFiles).as("Should be 1 invalid file").hasSize(1); + + // sleep for 1 second to ensure files will be old enough + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + Dataset compareToFileList = + spark + .createDataFrame(allFiles, FilePathLastModifiedRecord.class) + .withColumnRenamed("filePath", "file_path") + .withColumnRenamed("lastModified", "last_modified"); + + DeleteOrphanFiles.Result result1 = + actions + .deleteOrphanFiles(table) + .compareToFileList(compareToFileList) + .deleteWith(s -> {}) + .execute(); + assertThat(result1.orphanFileLocations()) + .as("Default olderThan interval should be safe") + .isEmpty(); + + DeleteOrphanFiles.Result result2 = + actions + .deleteOrphanFiles(table) + .compareToFileList(compareToFileList) + .olderThan(System.currentTimeMillis()) + .deleteWith(s -> {}) + .execute(); + assertThat(result2.orphanFileLocations()) + .as("Action should find 1 file") + .isEqualTo(invalidFilePaths); + assertThat(fs.exists(new Path(invalidFilePaths.get(0)))) + .as("Invalid file should be present") + .isTrue(); + + DeleteOrphanFiles.Result result3 = + actions + .deleteOrphanFiles(table) + .compareToFileList(compareToFileList) + .olderThan(System.currentTimeMillis()) + .execute(); + assertThat(result3.orphanFileLocations()) + .as("Action should delete 1 file") + .isEqualTo(invalidFilePaths); + assertThat(fs.exists(new Path(invalidFilePaths.get(0)))) + .as("Invalid file should not be present") + .isFalse(); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records); + expectedRecords.addAll(records); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + assertThat(actualRecords).as("Rows must match").isEqualTo(expectedRecords); + + List outsideLocationMockFiles = + 
Lists.newArrayList(new FilePathLastModifiedRecord("/tmp/mock1", new Timestamp(0L))); + + Dataset compareToFileListWithOutsideLocation = + spark + .createDataFrame(outsideLocationMockFiles, FilePathLastModifiedRecord.class) + .withColumnRenamed("filePath", "file_path") + .withColumnRenamed("lastModified", "last_modified"); + + DeleteOrphanFiles.Result result4 = + actions + .deleteOrphanFiles(table) + .compareToFileList(compareToFileListWithOutsideLocation) + .deleteWith(s -> {}) + .execute(); + assertThat(result4.orphanFileLocations()).as("Action should find nothing").isEmpty(); + } + + protected long waitUntilAfter(long timestampMillis) { + long current = System.currentTimeMillis(); + while (current <= timestampMillis) { + current = System.currentTimeMillis(); + } + return current; + } + + @TestTemplate + public void testRemoveOrphanFilesWithStatisticFiles() throws Exception { + assumeThat(formatVersion).isGreaterThanOrEqualTo(2); + Table table = TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), properties, tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + table.refresh(); + long snapshotId = table.currentSnapshot().snapshotId(); + long snapshotSequenceNumber = table.currentSnapshot().sequenceNumber(); + + File statsLocation = + new File(new URI(tableLocation)) + .toPath() + .resolve("data") + .resolve("some-stats-file") + .toFile(); + StatisticsFile statisticsFile; + try (PuffinWriter puffinWriter = Puffin.write(Files.localOutput(statsLocation)).build()) { + puffinWriter.add( + new Blob( + "some-blob-type", + ImmutableList.of(1), + snapshotId, + snapshotSequenceNumber, + ByteBuffer.wrap("blob content".getBytes(StandardCharsets.UTF_8)))); + puffinWriter.finish(); + statisticsFile = + new GenericStatisticsFile( + snapshotId, + statsLocation.toString(), + puffinWriter.fileSize(), + puffinWriter.footerSize(), + puffinWriter.writtenBlobsMetadata().stream() + .map(GenericBlobMetadata::from) + .collect(ImmutableList.toImmutableList())); + } + + Transaction transaction = table.newTransaction(); + transaction.updateStatistics().setStatistics(snapshotId, statisticsFile).commit(); + transaction.commitTransaction(); + + SparkActions.get() + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + + assertThat(statsLocation).as("stats file should exist").exists(); + assertThat(statsLocation.length()) + .as("stats file length") + .isEqualTo(statisticsFile.fileSizeInBytes()); + + transaction = table.newTransaction(); + transaction.updateStatistics().removeStatistics(statisticsFile.snapshotId()).commit(); + transaction.commitTransaction(); + + DeleteOrphanFiles.Result result = + SparkActions.get() + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + Iterable orphanFileLocations = result.orphanFileLocations(); + assertThat(orphanFileLocations).as("Should be orphan file").hasSize(1); + assertThat(Iterables.getOnlyElement(orphanFileLocations)) + .as("Deleted file") + .isEqualTo(statsLocation.toURI().toString()); + assertThat(statsLocation.exists()).as("stats file should be deleted").isFalse(); + } + + @TestTemplate + public void testPathsWithExtraSlashes() { + List validFiles = Lists.newArrayList("file:///dir1/dir2/file1"); + List actualFiles = Lists.newArrayList("file:///dir1/////dir2///file1"); + 
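+ // Runs of duplicate slashes are normalized, so the action should treat both URIs as the same + // file and report no orphan.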
executeTest(validFiles, actualFiles, Lists.newArrayList()); + } + + @TestTemplate + public void testPathsWithValidFileHavingNoAuthority() { + List validFiles = Lists.newArrayList("hdfs:///dir1/dir2/file1"); + List actualFiles = Lists.newArrayList("hdfs://servicename/dir1/dir2/file1"); + executeTest(validFiles, actualFiles, Lists.newArrayList()); + } + + @TestTemplate + public void testPathsWithActualFileHavingNoAuthority() { + List validFiles = Lists.newArrayList("hdfs://servicename/dir1/dir2/file1"); + List actualFiles = Lists.newArrayList("hdfs:///dir1/dir2/file1"); + executeTest(validFiles, actualFiles, Lists.newArrayList()); + } + + @TestTemplate + public void testPathsWithEqualSchemes() { + List validFiles = Lists.newArrayList("scheme1://bucket1/dir1/dir2/file1"); + List actualFiles = Lists.newArrayList("scheme2://bucket1/dir1/dir2/file1"); + assertThatThrownBy( + () -> + executeTest( + validFiles, + actualFiles, + Lists.newArrayList(), + ImmutableMap.of(), + ImmutableMap.of(), + DeleteOrphanFiles.PrefixMismatchMode.ERROR)) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith("Unable to determine whether certain files are orphan") + .hasMessageEndingWith("Conflicting authorities/schemes: [(scheme1, scheme2)]."); + + Map equalSchemes = Maps.newHashMap(); + equalSchemes.put("scheme1", "scheme"); + equalSchemes.put("scheme2", "scheme"); + executeTest( + validFiles, + actualFiles, + Lists.newArrayList(), + equalSchemes, + ImmutableMap.of(), + DeleteOrphanFiles.PrefixMismatchMode.ERROR); + } + + @TestTemplate + public void testPathsWithEqualAuthorities() { + List validFiles = Lists.newArrayList("hdfs://servicename1/dir1/dir2/file1"); + List actualFiles = Lists.newArrayList("hdfs://servicename2/dir1/dir2/file1"); + assertThatThrownBy( + () -> + executeTest( + validFiles, + actualFiles, + Lists.newArrayList(), + ImmutableMap.of(), + ImmutableMap.of(), + DeleteOrphanFiles.PrefixMismatchMode.ERROR)) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith("Unable to determine whether certain files are orphan") + .hasMessageEndingWith("Conflicting authorities/schemes: [(servicename1, servicename2)]."); + + Map equalAuthorities = Maps.newHashMap(); + equalAuthorities.put("servicename1", "servicename"); + equalAuthorities.put("servicename2", "servicename"); + executeTest( + validFiles, + actualFiles, + Lists.newArrayList(), + ImmutableMap.of(), + equalAuthorities, + DeleteOrphanFiles.PrefixMismatchMode.ERROR); + } + + @TestTemplate + public void testRemoveOrphanFileActionWithDeleteMode() { + List validFiles = Lists.newArrayList("hdfs://servicename1/dir1/dir2/file1"); + List actualFiles = Lists.newArrayList("hdfs://servicename2/dir1/dir2/file1"); + + executeTest( + validFiles, + actualFiles, + Lists.newArrayList("hdfs://servicename2/dir1/dir2/file1"), + ImmutableMap.of(), + ImmutableMap.of(), + DeleteOrphanFiles.PrefixMismatchMode.DELETE); + } + + private void executeTest( + List validFiles, List actualFiles, List expectedOrphanFiles) { + executeTest( + validFiles, + actualFiles, + expectedOrphanFiles, + ImmutableMap.of(), + ImmutableMap.of(), + DeleteOrphanFiles.PrefixMismatchMode.IGNORE); + } + + private void executeTest( + List validFiles, + List actualFiles, + List expectedOrphanFiles, + Map equalSchemes, + Map equalAuthorities, + DeleteOrphanFiles.PrefixMismatchMode mode) { + + StringToFileURI toFileUri = new StringToFileURI(equalSchemes, equalAuthorities); + + Dataset validFileDS = spark.createDataset(validFiles, Encoders.STRING()); + Dataset actualFileDS = 
spark.createDataset(actualFiles, Encoders.STRING()); + + List orphanFiles = + DeleteOrphanFilesSparkAction.findOrphanFiles( + spark, toFileUri.apply(actualFileDS), toFileUri.apply(validFileDS), mode); + assertThat(orphanFiles).isEqualTo(expectedOrphanFiles); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction3.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction3.java new file mode 100644 index 000000000000..14784da4f74f --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction3.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.StreamSupport; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.expressions.Transform; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestRemoveOrphanFilesAction3 extends TestRemoveOrphanFilesAction { + @TestTemplate + public void testSparkCatalogTable() throws Exception { + spark.conf().set("spark.sql.catalog.mycat", "org.apache.iceberg.spark.SparkCatalog"); + spark.conf().set("spark.sql.catalog.mycat.type", "hadoop"); + spark.conf().set("spark.sql.catalog.mycat.warehouse", tableLocation); + SparkCatalog cat = (SparkCatalog) spark.sessionState().catalogManager().catalog("mycat"); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "table" + ThreadLocalRandom.current().nextInt(1000)); + Transform[] transforms = {}; + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, properties); + SparkTable table = (SparkTable) cat.loadTable(id); + + sql("INSERT INTO mycat.default.%s VALUES (1,1,1)", id.name()); + + String location = table.table().location().replaceFirst("file:", ""); + String trashFile = "/data/trashfile" + ThreadLocalRandom.current().nextInt(1000); + new File(location + trashFile).createNewFile(); + + DeleteOrphanFiles.Result results = + SparkActions.get() + .deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + assertThat(StreamSupport.stream(results.orphanFileLocations().spliterator(), false)) + .as("trash file should be removed") + .anyMatch(file -> file.contains("file:" + location + 
trashFile)); + } + + @TestTemplate + public void testSparkCatalogNamedHadoopTable() throws Exception { + spark.conf().set("spark.sql.catalog.hadoop", "org.apache.iceberg.spark.SparkCatalog"); + spark.conf().set("spark.sql.catalog.hadoop.type", "hadoop"); + spark.conf().set("spark.sql.catalog.hadoop.warehouse", tableLocation); + SparkCatalog cat = (SparkCatalog) spark.sessionState().catalogManager().catalog("hadoop"); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "table" + ThreadLocalRandom.current().nextInt(1000)); + Transform[] transforms = {}; + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, properties); + SparkTable table = (SparkTable) cat.loadTable(id); + + sql("INSERT INTO hadoop.default.%s VALUES (1,1,1)", id.name()); + + String location = table.table().location().replaceFirst("file:", ""); + String trashFile = "/data/trashfile" + ThreadLocalRandom.current().nextInt(1000); + new File(location + trashFile).createNewFile(); + + DeleteOrphanFiles.Result results = + SparkActions.get() + .deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + assertThat(StreamSupport.stream(results.orphanFileLocations().spliterator(), false)) + .as("trash file should be removed") + .anyMatch(file -> file.contains("file:" + location + trashFile)); + } + + @TestTemplate + public void testSparkCatalogNamedHiveTable() throws Exception { + spark.conf().set("spark.sql.catalog.hive", "org.apache.iceberg.spark.SparkCatalog"); + spark.conf().set("spark.sql.catalog.hive.type", "hadoop"); + spark.conf().set("spark.sql.catalog.hive.warehouse", tableLocation); + SparkCatalog cat = (SparkCatalog) spark.sessionState().catalogManager().catalog("hive"); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "table" + ThreadLocalRandom.current().nextInt(1000)); + Transform[] transforms = {}; + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, properties); + SparkTable table = (SparkTable) cat.loadTable(id); + + sql("INSERT INTO hive.default.%s VALUES (1,1,1)", id.name()); + + String location = table.table().location().replaceFirst("file:", ""); + String trashFile = "/data/trashfile" + ThreadLocalRandom.current().nextInt(1000); + new File(location + trashFile).createNewFile(); + + DeleteOrphanFiles.Result results = + SparkActions.get() + .deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + + assertThat(StreamSupport.stream(results.orphanFileLocations().spliterator(), false)) + .as("trash file should be removed") + .anyMatch(file -> file.contains("file:" + location + trashFile)); + } + + @TestTemplate + public void testSparkSessionCatalogHadoopTable() throws Exception { + spark + .conf() + .set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog"); + spark.conf().set("spark.sql.catalog.spark_catalog.type", "hadoop"); + spark.conf().set("spark.sql.catalog.spark_catalog.warehouse", tableLocation); + SparkSessionCatalog cat = + (SparkSessionCatalog) spark.sessionState().catalogManager().v2SessionCatalog(); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "table" + ThreadLocalRandom.current().nextInt(1000)); + Transform[] transforms = {}; + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, properties); + SparkTable table = (SparkTable) cat.loadTable(id); + + sql("INSERT INTO default.%s VALUES (1,1,1)", id.name()); + + String location = table.table().location().replaceFirst("file:", 
""); + String trashFile = "/data/trashfile" + ThreadLocalRandom.current().nextInt(1000); + new File(location + trashFile).createNewFile(); + + DeleteOrphanFiles.Result results = + SparkActions.get() + .deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + assertThat(StreamSupport.stream(results.orphanFileLocations().spliterator(), false)) + .as("trash file should be removed") + .anyMatch(file -> file.contains("file:" + location + trashFile)); + } + + @TestTemplate + public void testSparkSessionCatalogHiveTable() throws Exception { + spark + .conf() + .set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog"); + spark.conf().set("spark.sql.catalog.spark_catalog.type", "hive"); + SparkSessionCatalog cat = + (SparkSessionCatalog) spark.sessionState().catalogManager().v2SessionCatalog(); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "sessioncattest"); + Transform[] transforms = {}; + cat.dropTable(id); + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, properties); + SparkTable table = (SparkTable) cat.loadTable(id); + + spark.sql("INSERT INTO default.sessioncattest VALUES (1,1,1)"); + + String location = table.table().location().replaceFirst("file:", ""); + String trashFile = "/data/trashfile" + ThreadLocalRandom.current().nextInt(1000); + new File(location + trashFile).createNewFile(); + + DeleteOrphanFiles.Result results = + SparkActions.get() + .deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + assertThat(StreamSupport.stream(results.orphanFileLocations().spliterator(), false)) + .as("trash file should be removed") + .anyMatch(file -> file.contains("file:" + location + trashFile)); + } + + @AfterEach + public void resetSparkSessionCatalog() throws Exception { + spark.conf().unset("spark.sql.catalog.spark_catalog"); + spark.conf().unset("spark.sql.catalog.spark_catalog.type"); + spark.conf().unset("spark.sql.catalog.spark_catalog.warehouse"); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteDataFilesAction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteDataFilesAction.java new file mode 100644 index 000000000000..6d722ace8af2 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteDataFilesAction.java @@ -0,0 +1,2312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.TableProperties.COMMIT_NUM_RETRIES; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.min; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doCallRealMethod; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; + +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionData; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.RewriteJobOrder; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.RewriteDataFiles; +import org.apache.iceberg.actions.RewriteDataFiles.Result; +import org.apache.iceberg.actions.RewriteDataFilesCommitManager; +import org.apache.iceberg.actions.RewriteFileGroup; +import org.apache.iceberg.actions.SizeBasedDataRewriter; +import org.apache.iceberg.actions.SizeBasedFileRewriter; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.deletes.BaseDVFileWriter; +import org.apache.iceberg.deletes.DVFileWriter; +import org.apache.iceberg.deletes.EqualityDeleteWriter; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.deletes.PositionDeleteWriter; +import org.apache.iceberg.encryption.EncryptedFiles; +import org.apache.iceberg.encryption.EncryptedOutputFile; +import org.apache.iceberg.encryption.EncryptionKeyMetadata; +import org.apache.iceberg.exceptions.CommitFailedException; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.CloseableIterable; 
+import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Streams; +import org.apache.iceberg.spark.FileRewriteCoordinator; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.spark.actions.RewriteDataFilesSparkAction.RewriteExecutionContext; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.iceberg.types.Comparators; +import org.apache.iceberg.types.Conversions; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.StructLikeMap; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.ArgumentMatcher; +import org.mockito.Mockito; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestRewriteDataFilesAction extends TestBase { + + @TempDir private File tableDir; + private static final int SCALE = 400000; + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + + private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA).identity("c1").build(); + + @Parameter private int formatVersion; + + @Parameters(name = "formatVersion = {0}") + protected static List parameters() { + return Arrays.asList(2, 3); + } + + private final FileRewriteCoordinator coordinator = FileRewriteCoordinator.get(); + private final ScanTaskSetManager manager = ScanTaskSetManager.get(); + private String tableLocation = null; + + @BeforeAll + public static void setupSpark() { + // disable AQE as tests assume that writes generate a particular number of files + spark.conf().set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"); + } + + @BeforeEach + public void setupTableLocation() throws Exception { + this.tableLocation = tableDir.toURI().toString(); + } + + private RewriteDataFilesSparkAction basicRewrite(Table table) { + // Always compact regardless of input files + table.refresh(); + return actions().rewriteDataFiles(table).option(SizeBasedFileRewriter.MIN_INPUT_FILES, "1"); + } + + @TestTemplate + public void testEmptyTable() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = + ImmutableMap.of(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + 
assertThat(table.currentSnapshot()).as("Table must be empty").isNull(); + + basicRewrite(table).execute(); + + assertThat(table.currentSnapshot()).as("Table must stay empty").isNull(); + } + + @TestTemplate + public void testBinPackUnpartitionedTable() { + Table table = createTable(4); + shouldHaveFiles(table, 4); + List expectedRecords = currentData(); + long dataSizeBefore = testDataSize(table); + + Result result = basicRewrite(table).execute(); + assertThat(result.rewrittenDataFilesCount()) + .as("Action should rewrite 4 data files") + .isEqualTo(4); + assertThat(result.addedDataFilesCount()).as("Action should add 1 data file").isOne(); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 1); + List actual = currentData(); + + assertEquals("Rows must match", expectedRecords, actual); + } + + @TestTemplate + public void testBinPackPartitionedTable() { + Table table = createTablePartitioned(4, 2); + shouldHaveFiles(table, 8); + List expectedRecords = currentData(); + long dataSizeBefore = testDataSize(table); + + Result result = basicRewrite(table).execute(); + assertThat(result.rewrittenDataFilesCount()) + .as("Action should rewrite 8 data files") + .isEqualTo(8); + assertThat(result.addedDataFilesCount()).as("Action should add 4 data file").isEqualTo(4); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 4); + List actualRecords = currentData(); + + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @TestTemplate + public void testBinPackWithFilter() { + Table table = createTablePartitioned(4, 2); + shouldHaveFiles(table, 8); + List expectedRecords = currentData(); + long dataSizeBefore = testDataSize(table); + + Result result = + basicRewrite(table) + .filter(Expressions.equal("c1", 1)) + .filter(Expressions.startsWith("c2", "foo")) + .execute(); + + assertThat(result.rewrittenDataFilesCount()) + .as("Action should rewrite 2 data files") + .isEqualTo(2); + assertThat(result.addedDataFilesCount()).as("Action should add 1 data file").isOne(); + assertThat(result.rewrittenBytesCount()).isGreaterThan(0L).isLessThan(dataSizeBefore); + + shouldHaveFiles(table, 7); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @TestTemplate + public void testBinPackWithFilterOnBucketExpression() { + Table table = createTablePartitioned(4, 2); + + shouldHaveFiles(table, 8); + List expectedRecords = currentData(); + long dataSizeBefore = testDataSize(table); + + Result result = + basicRewrite(table) + .filter(Expressions.equal("c1", 1)) + .filter(Expressions.equal(Expressions.bucket("c2", 2), 0)) + .execute(); + + assertThat(result) + .extracting(Result::rewrittenDataFilesCount, Result::addedDataFilesCount) + .as("Action should rewrite 2 data files into 1 data file") + .contains(2, 1); + assertThat(result.rewrittenBytesCount()).isGreaterThan(0L).isLessThan(dataSizeBefore); + + shouldHaveFiles(table, 7); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @TestTemplate + public void testBinPackAfterPartitionChange() { + Table table = createTable(); + + writeRecords(20, SCALE, 20); + shouldHaveFiles(table, 20); + table.updateSpec().addField(Expressions.ref("c1")).commit(); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFiles.Result result = + basicRewrite(table) + .option(SizeBasedFileRewriter.MIN_INPUT_FILES, "1") + 
.option( + SizeBasedFileRewriter.MIN_FILE_SIZE_BYTES, + Integer.toString(averageFileSize(table) + 1000)) + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, + Integer.toString(averageFileSize(table) + 1001)) + .execute(); + + assertThat(result.rewriteResults()) + .as("Should have 1 fileGroup because all files were not correctly partitioned") + .hasSize(1); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveFiles(table, 20); + } + + @TestTemplate + public void testBinPackWithDeletes() throws IOException { + assumeThat(formatVersion).isGreaterThanOrEqualTo(2); + Table table = createTablePartitioned(4, 2); + shouldHaveFiles(table, 8); + table.refresh(); + + List dataFiles = TestHelpers.dataFiles(table); + int total = (int) dataFiles.stream().mapToLong(ContentFile::recordCount).sum(); + + RowDelta rowDelta = table.newRowDelta(); + if (formatVersion >= 3) { + // delete 1 position for data files 0, 1, 2 + for (int i = 0; i < 3; i++) { + writeDV(table, dataFiles.get(i).partition(), dataFiles.get(i).location(), 1) + .forEach(rowDelta::addDeletes); + } + + // delete 2 positions for data files 3, 4 + for (int i = 3; i < 5; i++) { + writeDV(table, dataFiles.get(i).partition(), dataFiles.get(i).location(), 2) + .forEach(rowDelta::addDeletes); + } + } else { + // add 1 delete file for data files 0, 1, 2 + for (int i = 0; i < 3; i++) { + writePosDeletesToFile(table, dataFiles.get(i), 1).forEach(rowDelta::addDeletes); + } + + // add 2 delete files for data files 3, 4 + for (int i = 3; i < 5; i++) { + writePosDeletesToFile(table, dataFiles.get(i), 2).forEach(rowDelta::addDeletes); + } + } + + rowDelta.commit(); + table.refresh(); + List expectedRecords = currentData(); + long dataSizeBefore = testDataSize(table); + + if (formatVersion >= 3) { + Result result = + actions() + .rewriteDataFiles(table) + // do not include any file based on bin pack file size configs + .option(SizeBasedFileRewriter.MIN_FILE_SIZE_BYTES, "0") + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Long.toString(Long.MAX_VALUE - 1)) + .option(SizeBasedFileRewriter.MAX_FILE_SIZE_BYTES, Long.toString(Long.MAX_VALUE)) + // set DELETE_FILE_THRESHOLD to 1 since DVs only produce one delete file per data file + .option(SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, "1") + .execute(); + assertThat(result.rewrittenDataFilesCount()) + .as("Action should rewrite 5 data files") + .isEqualTo(5); + assertThat(result.rewrittenBytesCount()).isGreaterThan(0L).isLessThan(dataSizeBefore); + } else { + Result result = + actions() + .rewriteDataFiles(table) + // do not include any file based on bin pack file size configs + .option(SizeBasedFileRewriter.MIN_FILE_SIZE_BYTES, "0") + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Long.toString(Long.MAX_VALUE - 1)) + .option(SizeBasedFileRewriter.MAX_FILE_SIZE_BYTES, Long.toString(Long.MAX_VALUE)) + .option(SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, "2") + .execute(); + assertThat(result.rewrittenDataFilesCount()) + .as("Action should rewrite 2 data files") + .isEqualTo(2); + assertThat(result.rewrittenBytesCount()).isGreaterThan(0L).isLessThan(dataSizeBefore); + } + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + assertThat(actualRecords).as("7 rows are removed").hasSize(total - 7); + } + + @TestTemplate + public void 
testRemoveDangledEqualityDeletesPartitionEvolution() { + Table table = + TABLES.create( + SCHEMA, + SPEC, + ImmutableMap.of(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)), + tableLocation); + + // data seq = 1, write 4 files in 2 partitions + List records1 = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB")); + writeRecords(records1); + List records2 = + Lists.newArrayList( + new ThreeColumnRecord(0, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(0, "DDDDDDDDDD", "DDDD")); + writeRecords(records2); + table.refresh(); + shouldHaveFiles(table, 4); + + // data seq = 2 & 3, write 2 equality deletes in both partitions + writeEqDeleteRecord(table, "c1", 1, "c3", "AAAA"); + writeEqDeleteRecord(table, "c1", 2, "c3", "CCCC"); + table.refresh(); + Set existingDeletes = TestHelpers.deleteFiles(table); + assertThat(existingDeletes) + .as("Only one equality delete c1=1 is used in query planning") + .hasSize(1); + + // partition evolution + table.refresh(); + table.updateSpec().addField(Expressions.ref("c3")).commit(); + + // data seq = 4, write 2 new data files in both partitions for evolved spec + List records3 = + Lists.newArrayList( + new ThreeColumnRecord(1, "A", "CCCC"), new ThreeColumnRecord(2, "D", "DDDD")); + writeRecords(records3); + + List originalData = currentData(); + + RewriteDataFiles.Result result = + basicRewrite(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .filter(Expressions.equal("c1", 1)) + .option(RewriteDataFiles.REMOVE_DANGLING_DELETES, "true") + .execute(); + + existingDeletes = TestHelpers.deleteFiles(table); + assertThat(existingDeletes).as("Shall pruned dangling deletes after rewrite").hasSize(0); + + assertThat(result) + .extracting( + Result::addedDataFilesCount, + Result::rewrittenDataFilesCount, + Result::removedDeleteFilesCount) + .as("Should compact 3 data files into 2 and remove both dangled equality delete file") + .containsExactly(2, 3, 2); + shouldHaveMinSequenceNumberInPartition(table, "data_file.partition.c1 == 1", 5); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 7); + shouldHaveFiles(table, 5); + } + + @TestTemplate + public void testRemoveDangledPositionDeletesPartitionEvolution() throws IOException { + Table table = + TABLES.create( + SCHEMA, + SPEC, + ImmutableMap.of(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)), + tableLocation); + + // data seq = 1, write 4 files in 2 partitions + writeRecords(2, 2, 2); + List dataFilesBefore = TestHelpers.dataFiles(table, null); + shouldHaveFiles(table, 4); + + DeleteFile deleteFile; + // data seq = 2, write 1 position deletes in c1=1 + DataFile dataFile = dataFilesBefore.get(3); + if (formatVersion >= 3) { + deleteFile = writeDV(table, dataFile.partition(), dataFile.location(), 1).get(0); + } else { + deleteFile = writePosDeletesToFile(table, dataFile, 1).get(0); + } + table.newRowDelta().addDeletes(deleteFile).commit(); + + // partition evolution + table.updateSpec().addField(Expressions.ref("c3")).commit(); + + // data seq = 3, write 1 new data files in c1=1 for evolved spec + writeRecords(1, 1, 1); + shouldHaveFiles(table, 5); + List expectedRecords = currentData(); + + Result result = + actions() + .rewriteDataFiles(table) + .filter(Expressions.equal("c1", 1)) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .option(RewriteDataFiles.REMOVE_DANGLING_DELETES, "true") + .execute(); + + 
assertThat(result) + .extracting( + Result::addedDataFilesCount, + Result::rewrittenDataFilesCount, + Result::removedDeleteFilesCount) + .as("Should rewrite 2 data files into 1 and remove 1 dangled position delete file") + .containsExactly(1, 2, 1); + shouldHaveMinSequenceNumberInPartition(table, "data_file.partition.c1 == 1", 3); + + shouldHaveSnapshots(table, 5); + assertThat(table.currentSnapshot().summary().get("total-position-deletes")).isEqualTo("0"); + assertEquals("Rows must match", expectedRecords, currentData()); + } + + @TestTemplate + public void testBinPackWithDeleteAllData() throws IOException { + assumeThat(formatVersion).isGreaterThanOrEqualTo(2); + Table table = createTablePartitioned(1, 1, 1); + shouldHaveFiles(table, 1); + table.refresh(); + + List dataFiles = TestHelpers.dataFiles(table); + int total = (int) dataFiles.stream().mapToLong(ContentFile::recordCount).sum(); + + RowDelta rowDelta = table.newRowDelta(); + DataFile dataFile = dataFiles.get(0); + // remove all data + if (formatVersion >= 3) { + writeDV(table, dataFile.partition(), dataFile.location(), total) + .forEach(rowDelta::addDeletes); + } else { + writePosDeletesToFile(table, dataFile, total).forEach(rowDelta::addDeletes); + } + + rowDelta.commit(); + table.refresh(); + List expectedRecords = currentData(); + long dataSizeBefore = testDataSize(table); + + Result result = + actions() + .rewriteDataFiles(table) + .option(SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, "1") + .execute(); + assertThat(result.rewrittenDataFilesCount()).as("Action should rewrite 1 data files").isOne(); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + assertThat(table.currentSnapshot().dataManifests(table.io()).get(0).existingFilesCount()) + .as("Data manifest should not have existing data file") + .isZero(); + + assertThat((long) table.currentSnapshot().dataManifests(table.io()).get(0).deletedFilesCount()) + .as("Data manifest should have 1 delete data file") + .isEqualTo(1L); + + assertThat(table.currentSnapshot().deleteManifests(table.io()).get(0).addedRowsCount()) + .as("Delete manifest added row count should equal total count") + .isEqualTo(total); + } + + @TestTemplate + public void testBinPackWithStartingSequenceNumber() { + assumeThat(formatVersion).isGreaterThanOrEqualTo(2); + Table table = createTablePartitioned(4, 2); + shouldHaveFiles(table, 8); + List expectedRecords = currentData(); + table.refresh(); + long oldSequenceNumber = table.currentSnapshot().sequenceNumber(); + long dataSizeBefore = testDataSize(table); + + Result result = + basicRewrite(table).option(RewriteDataFiles.USE_STARTING_SEQUENCE_NUMBER, "true").execute(); + assertThat(result.rewrittenDataFilesCount()) + .as("Action should rewrite 8 data files") + .isEqualTo(8); + assertThat(result.addedDataFilesCount()).as("Action should add 4 data files").isEqualTo(4); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 4); + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + + table.refresh(); + assertThat(table.currentSnapshot().sequenceNumber()) + .as("Table sequence number should be incremented") + .isGreaterThan(oldSequenceNumber); + + Dataset rows = SparkTableUtil.loadMetadataTable(spark, table, MetadataTableType.ENTRIES); + for (Row row : rows.collectAsList()) { + if (row.getInt(0) == 1) { + assertThat(row.getLong(2)) + .as("Expect 
old sequence number for added entries") + .isEqualTo(oldSequenceNumber); + } + } + } + + @TestTemplate + public void testBinPackWithStartingSequenceNumberV1Compatibility() { + Map properties = ImmutableMap.of(TableProperties.FORMAT_VERSION, "1"); + Table table = createTablePartitioned(4, 2, SCALE, properties); + shouldHaveFiles(table, 8); + List expectedRecords = currentData(); + table.refresh(); + long oldSequenceNumber = table.currentSnapshot().sequenceNumber(); + assertThat(oldSequenceNumber).as("Table sequence number should be 0").isZero(); + long dataSizeBefore = testDataSize(table); + + Result result = + basicRewrite(table).option(RewriteDataFiles.USE_STARTING_SEQUENCE_NUMBER, "true").execute(); + assertThat(result.rewrittenDataFilesCount()) + .as("Action should rewrite 8 data files") + .isEqualTo(8); + assertThat(result.addedDataFilesCount()).as("Action should add 4 data files").isEqualTo(4); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 4); + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + + table.refresh(); + assertThat(table.currentSnapshot().sequenceNumber()) + .as("Table sequence number should still be 0") + .isEqualTo(oldSequenceNumber); + + Dataset rows = SparkTableUtil.loadMetadataTable(spark, table, MetadataTableType.ENTRIES); + for (Row row : rows.collectAsList()) { + assertThat(row.getLong(2)) + .as("Expect sequence number 0 for all entries") + .isEqualTo(oldSequenceNumber); + } + } + + @TestTemplate + public void testRewriteLargeTableHasResiduals() { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).build(); + Map options = + ImmutableMap.of( + TableProperties.FORMAT_VERSION, + String.valueOf(formatVersion), + TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, + "100"); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + // all records belong to the same partition + List records = Lists.newArrayList(); + for (int i = 0; i < 100; i++) { + records.add(new ThreeColumnRecord(i, String.valueOf(i), String.valueOf(i % 4))); + } + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + writeDF(df); + + List expectedRecords = currentData(); + + table.refresh(); + + CloseableIterable tasks = + table.newScan().ignoreResiduals().filter(Expressions.equal("c3", "0")).planFiles(); + for (FileScanTask task : tasks) { + assertThat(task.residual()) + .as("Residuals must be ignored") + .isEqualTo(Expressions.alwaysTrue()); + } + + shouldHaveFiles(table, 2); + + long dataSizeBefore = testDataSize(table); + Result result = basicRewrite(table).filter(Expressions.equal("c3", "0")).execute(); + assertThat(result.rewrittenDataFilesCount()) + .as("Action should rewrite 2 data files") + .isEqualTo(2); + assertThat(result.addedDataFilesCount()).as("Action should add 1 data file").isOne(); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + List actualRecords = currentData(); + + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @TestTemplate + public void testBinPackSplitLargeFile() { + Table table = createTable(1); + shouldHaveFiles(table, 1); + + List expectedRecords = currentData(); + long targetSize = testDataSize(table) / 2; + + long dataSizeBefore = testDataSize(table); + Result result = + basicRewrite(table) + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Long.toString(targetSize)) + .option(SizeBasedFileRewriter.MAX_FILE_SIZE_BYTES, Long.toString(targetSize * 2 - 2000)) + .execute(); + + 
assertThat(result.rewrittenDataFilesCount()).as("Action should delete 1 data files").isOne(); + assertThat(result.addedDataFilesCount()).as("Action should add 2 data files").isEqualTo(2); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 2); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @TestTemplate + public void testBinPackCombineMixedFiles() { + Table table = createTable(1); // 400000 + shouldHaveFiles(table, 1); + + // Add one more small file, and one large file + writeRecords(1, SCALE); + writeRecords(1, SCALE * 3); + shouldHaveFiles(table, 3); + + List expectedRecords = currentData(); + + int targetSize = averageFileSize(table); + + long dataSizeBefore = testDataSize(table); + Result result = + basicRewrite(table) + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(targetSize + 1000)) + .option(SizeBasedFileRewriter.MAX_FILE_SIZE_BYTES, Integer.toString(targetSize + 80000)) + .option(SizeBasedFileRewriter.MIN_FILE_SIZE_BYTES, Integer.toString(targetSize - 1000)) + .execute(); + + assertThat(result.rewrittenDataFilesCount()) + .as("Action should delete 3 data files") + .isEqualTo(3); + // Should Split the big files into 3 pieces, one of which should be combined with the two + // smaller files + assertThat(result.addedDataFilesCount()).as("Action should add 3 data files").isEqualTo(3); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 3); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @TestTemplate + public void testBinPackCombineMediumFiles() { + Table table = createTable(4); + shouldHaveFiles(table, 4); + + List expectedRecords = currentData(); + int targetSize = ((int) testDataSize(table) / 3); + // The test is to see if we can combine parts of files to make files of the correct size + + long dataSizeBefore = testDataSize(table); + Result result = + basicRewrite(table) + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(targetSize)) + .option( + SizeBasedFileRewriter.MAX_FILE_SIZE_BYTES, + Integer.toString((int) (targetSize * 1.8))) + .option( + SizeBasedFileRewriter.MIN_FILE_SIZE_BYTES, + Integer.toString(targetSize - 100)) // All files too small + .execute(); + + assertThat(result.rewrittenDataFilesCount()) + .as("Action should delete 4 data files") + .isEqualTo(4); + assertThat(result.addedDataFilesCount()).as("Action should add 3 data files").isEqualTo(3); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 3); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @TestTemplate + public void testPartialProgressEnabled() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + table.updateProperties().set(COMMIT_NUM_RETRIES, "10").commit(); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + // Perform a rewrite but only allow 2 files to be compacted at a time + RewriteDataFiles.Result result = + basicRewrite(table) + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "10") + .execute(); + + assertThat(result.rewriteResults()).as("Should have 10 fileGroups").hasSize(10); + 
assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + shouldHaveSnapshots(table, 11); + shouldHaveACleanCache(table); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + } + + @TestTemplate + public void testMultipleGroups() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + // Perform a rewrite but only allow 2 files to be compacted at a time + RewriteDataFiles.Result result = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(SizeBasedFileRewriter.MIN_INPUT_FILES, "1") + .execute(); + + assertThat(result.rewriteResults()).as("Should have 10 fileGroups").hasSize(10); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + } + + @TestTemplate + public void testPartialProgressMaxCommits() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + // Perform a rewrite but only allow 2 files to be compacted at a time + RewriteDataFiles.Result result = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3") + .execute(); + + assertThat(result.rewriteResults()).as("Should have 10 fileGroups").hasSize(10); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 4); + shouldHaveACleanCache(table); + } + + @TestTemplate + public void testSingleCommitWithRewriteFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)); + + RewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite); + + // Fail groups 1, 3, and 7 during rewrite + GroupInfoMatcher failGroup = new GroupInfoMatcher(1, 3, 7); + doThrow(new RuntimeException("Rewrite Failed")) + .when(spyRewrite) + .rewriteFiles(any(), argThat(failGroup)); + + assertThatThrownBy(spyRewrite::execute) + .isInstanceOf(RuntimeException.class) + .hasMessage("Rewrite Failed"); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 1); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @TestTemplate + public void testSingleCommitWithCommitFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)); + + RewriteDataFilesSparkAction spyRewrite = spy(realRewrite); + 
RewriteDataFilesCommitManager util = spy(new RewriteDataFilesCommitManager(table)); + + // Fail to commit + doThrow(new CommitFailedException("Commit Failure")).when(util).commitFileGroups(any()); + + doReturn(util).when(spyRewrite).commitManager(table.currentSnapshot().snapshotId()); + + assertThatThrownBy(spyRewrite::execute) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Cannot commit rewrite"); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 1); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @TestTemplate + public void testCommitFailsWithUncleanableFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)); + + RewriteDataFilesSparkAction spyRewrite = spy(realRewrite); + RewriteDataFilesCommitManager util = spy(new RewriteDataFilesCommitManager(table)); + + // Fail to commit with an arbitrary failure and validate that orphans are not cleaned up + doThrow(new RuntimeException("Arbitrary Failure")).when(util).commitFileGroups(any()); + + doReturn(util).when(spyRewrite).commitManager(table.currentSnapshot().snapshotId()); + + assertThatThrownBy(spyRewrite::execute) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Arbitrary Failure"); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 1); + shouldHaveOrphans(table); + shouldHaveACleanCache(table); + } + + @TestTemplate + public void testParallelSingleCommitWithRewriteFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "3"); + + RewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite); + + // Fail groups 1, 3, and 7 during rewrite + GroupInfoMatcher failGroup = new GroupInfoMatcher(1, 3, 7); + doThrow(new CommitFailedException("Rewrite Failed")) + .when(spyRewrite) + .rewriteFiles(any(), argThat(failGroup)); + + assertThatThrownBy(spyRewrite::execute) + .isInstanceOf(CommitFailedException.class) + .hasMessage("Rewrite Failed"); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 1); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @TestTemplate + public void testPartialProgressWithRewriteFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3"); + + RewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite); + + // Fail groups 1, 3, and 7 during rewrite + 
GroupInfoMatcher failGroup = new GroupInfoMatcher(1, 3, 7); + doThrow(new RuntimeException("Rewrite Failed")) + .when(spyRewrite) + .rewriteFiles(any(), argThat(failGroup)); + + RewriteDataFiles.Result result = spyRewrite.execute(); + + assertThat(result.rewriteResults()).hasSize(7); + assertThat(result.rewriteFailures()).hasSize(3); + assertThat(result.failedDataFilesCount()).isEqualTo(6); + assertThat(result.rewrittenBytesCount()).isGreaterThan(0L).isLessThan(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + // With 10 original groups and Max Commits of 3, we should have commits with 4, 4, and 2. + // removing 3 groups leaves us with only 2 new commits, 4 and 3 + shouldHaveSnapshots(table, 3); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @TestTemplate + public void testParallelPartialProgressWithRewriteFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "3") + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3"); + + RewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite); + + // Fail groups 1, 3, and 7 during rewrite + GroupInfoMatcher failGroup = new GroupInfoMatcher(1, 3, 7); + doThrow(new RuntimeException("Rewrite Failed")) + .when(spyRewrite) + .rewriteFiles(any(), argThat(failGroup)); + + RewriteDataFiles.Result result = spyRewrite.execute(); + + assertThat(result.rewriteResults()).hasSize(7); + assertThat(result.rewriteFailures()).hasSize(3); + assertThat(result.failedDataFilesCount()).isEqualTo(6); + assertThat(result.rewrittenBytesCount()).isGreaterThan(0L).isLessThan(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + // With 10 original groups and max commits of 3, we have 4 groups per commit. + // Removing 3 groups, we are left with 4 groups and 3 groups in two commits. 
+ shouldHaveSnapshots(table, 3); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @TestTemplate + public void testParallelPartialProgressWithCommitFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "3") + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3"); + + RewriteDataFilesSparkAction spyRewrite = spy(realRewrite); + RewriteDataFilesCommitManager util = spy(new RewriteDataFilesCommitManager(table)); + + // First and Third commits work, second does not + doCallRealMethod() + .doThrow(new CommitFailedException("Commit Failed")) + .doCallRealMethod() + .when(util) + .commitFileGroups(any()); + + doReturn(util).when(spyRewrite).commitManager(table.currentSnapshot().snapshotId()); + + RewriteDataFiles.Result result = spyRewrite.execute(); + + // Commit 1: 4/4 + Commit 2 failed 0/4 + Commit 3: 2/2 == 6 out of 10 total groups committed + assertThat(result.rewriteResults()).as("Should have 6 fileGroups").hasSize(6); + assertThat(result.rewrittenBytesCount()).isGreaterThan(0L).isLessThan(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + // Only 2 new commits because we broke one + shouldHaveSnapshots(table, 3); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @TestTemplate + public void testParallelPartialProgressWithMaxFailedCommits() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "3") + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_FAILED_COMMITS, "0"); + + RewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite); + + // Fail groups 1, 3, and 7 during rewrite + GroupInfoMatcher failGroup = new GroupInfoMatcher(1, 3, 7); + doThrow(new RuntimeException("Rewrite Failed")) + .when(spyRewrite) + .rewriteFiles(any(), argThat(failGroup)); + + assertThatThrownBy(() -> spyRewrite.execute()) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining( + "1 rewrite commits failed. This is more than the maximum allowed failures of 0"); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + // With 10 original groups and max commits of 3, we have 4 groups per commit. + // Removing 3 groups, we are left with 4 groups and 3 groups in two commits. + // Adding max allowed failed commits doesn't change the number of successful commits. 
+ shouldHaveSnapshots(table, 3); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @TestTemplate + public void testInvalidOptions() { + Table table = createTable(20); + + assertThatThrownBy( + () -> + basicRewrite(table) + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "-5") + .execute()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Cannot set partial-progress.max-commits to -5, " + + "the value must be positive when partial-progress.enabled is true"); + + assertThatThrownBy( + () -> + basicRewrite(table) + .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "-5") + .execute()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Cannot set max-concurrent-file-group-rewrites to -5, the value must be positive."); + + assertThatThrownBy(() -> basicRewrite(table).option("foobarity", "-5").execute()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Cannot use options [foobarity], they are not supported by the action or the rewriter BIN-PACK"); + + assertThatThrownBy( + () -> basicRewrite(table).option(RewriteDataFiles.REWRITE_JOB_ORDER, "foo").execute()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid rewrite job order name: foo"); + + assertThatThrownBy( + () -> + basicRewrite(table) + .sort(SortOrder.builderFor(table.schema()).asc("c2").build()) + .option(SparkShufflingDataRewriter.SHUFFLE_PARTITIONS_PER_FILE, "5") + .execute()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("requires enabling Iceberg Spark session extensions"); + } + + @TestTemplate + public void testSortMultipleGroups() { + Table table = createTable(20); + shouldHaveFiles(table, 20); + table.replaceSortOrder().asc("c2").commit(); + shouldHaveLastCommitUnsorted(table, "c2"); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + // Perform a rewrite but only allow 2 files to be compacted at a time + RewriteDataFiles.Result result = + basicRewrite(table) + .sort() + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .execute(); + + assertThat(result.rewriteResults()).as("Should have 10 fileGroups").hasSize(10); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + } + + @TestTemplate + public void testSimpleSort() { + Table table = createTable(20); + shouldHaveFiles(table, 20); + table.replaceSortOrder().asc("c2").commit(); + shouldHaveLastCommitUnsorted(table, "c2"); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort() + .option(SizeBasedFileRewriter.MIN_INPUT_FILES, "1") + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table))) + .execute(); + + assertThat(result.rewriteResults()).as("Should have 1 fileGroups").hasSize(1); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + 
shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitSorted(table, "c2"); + } + + @TestTemplate + public void testSortAfterPartitionChange() { + Table table = createTable(20); + shouldHaveFiles(table, 20); + table.updateSpec().addField(Expressions.bucket("c1", 4)).commit(); + table.replaceSortOrder().asc("c2").commit(); + shouldHaveLastCommitUnsorted(table, "c2"); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort() + .option(SizeBasedFileRewriter.MIN_INPUT_FILES, "1") + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table))) + .execute(); + + assertThat(result.rewriteResults()) + .as("Should have 1 fileGroups because all files were not correctly partitioned") + .hasSize(1); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitSorted(table, "c2"); + } + + @TestTemplate + public void testSortCustomSortOrder() { + Table table = createTable(20); + shouldHaveLastCommitUnsorted(table, "c2"); + shouldHaveFiles(table, 20); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort(SortOrder.builderFor(table.schema()).asc("c2").build()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table))) + .execute(); + + assertThat(result.rewriteResults()).as("Should have 1 fileGroups").hasSize(1); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitSorted(table, "c2"); + } + + @TestTemplate + public void testSortCustomSortOrderRequiresRepartition() { + int partitions = 4; + Table table = createTable(); + writeRecords(20, SCALE, partitions); + shouldHaveLastCommitUnsorted(table, "c3"); + + // Add a partition column so this requires repartitioning + table.updateSpec().addField("c1").commit(); + // Add a sort order which our repartitioning needs to ignore + table.replaceSortOrder().asc("c2").apply(); + shouldHaveFiles(table, 20); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort(SortOrder.builderFor(table.schema()).asc("c3").build()) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, + Integer.toString(averageFileSize(table) / partitions)) + .execute(); + + assertThat(result.rewriteResults()).as("Should have 1 fileGroups").hasSize(1); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + 
shouldHaveLastCommitUnsorted(table, "c2"); + shouldHaveLastCommitSorted(table, "c3"); + } + + @TestTemplate + public void testAutoSortShuffleOutput() { + Table table = createTable(20); + shouldHaveLastCommitUnsorted(table, "c2"); + shouldHaveFiles(table, 20); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort(SortOrder.builderFor(table.schema()).asc("c2").build()) + .option( + SizeBasedFileRewriter.MAX_FILE_SIZE_BYTES, + Integer.toString((averageFileSize(table) / 2) + 2)) + // Divide files in 2 + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, + Integer.toString(averageFileSize(table) / 2)) + .option(SizeBasedFileRewriter.MIN_INPUT_FILES, "1") + .execute(); + + assertThat(result.rewriteResults()).as("Should have 1 fileGroups").hasSize(1); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + assertThat(result.rewriteResults()).as("Should have 1 fileGroups").hasSize(1); + assertThat(table.currentSnapshot().addedDataFiles(table.io())) + .as("Should have written 40+ files") + .hasSizeGreaterThanOrEqualTo(40); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitSorted(table, "c2"); + } + + @TestTemplate + public void testCommitStateUnknownException() { + Table table = createTable(20); + shouldHaveFiles(table, 20); + + List originalData = currentData(); + + RewriteDataFilesSparkAction action = basicRewrite(table); + RewriteDataFilesSparkAction spyAction = spy(action); + RewriteDataFilesCommitManager util = spy(new RewriteDataFilesCommitManager(table)); + + doAnswer( + invocationOnMock -> { + invocationOnMock.callRealMethod(); + throw new CommitStateUnknownException(new RuntimeException("Unknown State")); + }) + .when(util) + .commitFileGroups(any()); + + doReturn(util).when(spyAction).commitManager(table.currentSnapshot().snapshotId()); + + assertThatThrownBy(spyAction::execute) + .isInstanceOf(CommitStateUnknownException.class) + .hasMessageStartingWith( + "Unknown State\n" + "Cannot determine whether the commit was successful or not"); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); // Commit actually Succeeded + } + + @TestTemplate + public void testZOrderSort() { + int originalFiles = 20; + Table table = createTable(originalFiles); + shouldHaveLastCommitUnsorted(table, "c2"); + shouldHaveFiles(table, originalFiles); + + List originalData = currentData(); + double originalFilesC2 = percentFilesRequired(table, "c2", "foo23"); + double originalFilesC3 = percentFilesRequired(table, "c3", "bar21"); + double originalFilesC2C3 = + percentFilesRequired(table, new String[] {"c2", "c3"}, new String[] {"foo23", "bar23"}); + + assertThat(originalFilesC2).as("Should require all files to scan c2").isGreaterThan(0.99); + assertThat(originalFilesC3).as("Should require all files to scan c3").isGreaterThan(0.99); + + long dataSizeBefore = testDataSize(table); + RewriteDataFiles.Result result = + basicRewrite(table) + .zOrder("c2", "c3") + .option( + SizeBasedFileRewriter.MAX_FILE_SIZE_BYTES, + Integer.toString((averageFileSize(table) / 2) + 2)) + // Divide files in 2 + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, + Integer.toString(averageFileSize(table) 
/ 2)) + .option(SizeBasedFileRewriter.MIN_INPUT_FILES, "1") + .execute(); + + assertThat(result.rewriteResults()).as("Should have 1 fileGroups").hasSize(1); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + assertThat(table.currentSnapshot().addedDataFiles(table.io())) + .as("Should have written 40+ files") + .hasSizeGreaterThanOrEqualTo(40); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + + double filesScannedC2 = percentFilesRequired(table, "c2", "foo23"); + double filesScannedC3 = percentFilesRequired(table, "c3", "bar21"); + double filesScannedC2C3 = + percentFilesRequired(table, new String[] {"c2", "c3"}, new String[] {"foo23", "bar23"}); + + assertThat(originalFilesC2) + .as("Should have reduced the number of files required for c2") + .isGreaterThan(filesScannedC2); + assertThat(originalFilesC3) + .as("Should have reduced the number of files required for c3") + .isGreaterThan(filesScannedC3); + assertThat(originalFilesC2C3) + .as("Should have reduced the number of files required for c2,c3 predicate") + .isGreaterThan(filesScannedC2C3); + } + + @TestTemplate + public void testZOrderAllTypesSort() { + spark.conf().set("spark.sql.ansi.enabled", "false"); + Table table = createTypeTestTable(); + shouldHaveFiles(table, 10); + + List originalRaw = + spark.read().format("iceberg").load(tableLocation).sort("longCol").collectAsList(); + List originalData = rowsToJava(originalRaw); + long dataSizeBefore = testDataSize(table); + + // TODO add in UUID when it is supported in Spark + RewriteDataFiles.Result result = + basicRewrite(table) + .zOrder( + "longCol", + "intCol", + "floatCol", + "doubleCol", + "dateCol", + "timestampCol", + "stringCol", + "binaryCol", + "booleanCol") + .option(SizeBasedFileRewriter.MIN_INPUT_FILES, "1") + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + + assertThat(result.rewriteResults()).as("Should have 1 fileGroups").hasSize(1); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + assertThat(table.currentSnapshot().addedDataFiles(table.io())) + .as("Should have written 1 file") + .hasSize(1); + + table.refresh(); + + List postRaw = + spark.read().format("iceberg").load(tableLocation).sort("longCol").collectAsList(); + List postRewriteData = rowsToJava(postRaw); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + } + + @TestTemplate + public void testInvalidAPIUsage() { + Table table = createTable(1); + + SortOrder sortOrder = SortOrder.builderFor(table.schema()).asc("c2").build(); + + assertThatThrownBy(() -> actions().rewriteDataFiles(table).binPack().sort()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Must use only one rewriter type (bin-pack, sort, zorder)"); + + assertThatThrownBy(() -> actions().rewriteDataFiles(table).sort(sortOrder).binPack()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Must use only one rewriter type (bin-pack, sort, zorder)"); + + assertThatThrownBy(() -> actions().rewriteDataFiles(table).sort(sortOrder).binPack()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Must use only one rewriter type (bin-pack, sort, zorder)"); + } + + @TestTemplate + public void testRewriteJobOrderBytesAsc() { + assumeThat(formatVersion).isGreaterThanOrEqualTo(2); + Table table = 
createTablePartitioned(4, 2); + writeRecords(1, SCALE, 1); + writeRecords(2, SCALE, 2); + writeRecords(3, SCALE, 3); + writeRecords(4, SCALE, 4); + + RewriteDataFilesSparkAction basicRewrite = basicRewrite(table).binPack(); + List expected = + toGroupStream(table, basicRewrite) + .mapToLong(RewriteFileGroup::sizeInBytes) + .boxed() + .collect(Collectors.toList()); + + RewriteDataFilesSparkAction jobOrderRewrite = + basicRewrite(table) + .option(RewriteDataFiles.REWRITE_JOB_ORDER, RewriteJobOrder.BYTES_ASC.orderName()) + .binPack(); + List actual = + toGroupStream(table, jobOrderRewrite) + .mapToLong(RewriteFileGroup::sizeInBytes) + .boxed() + .collect(Collectors.toList()); + + expected.sort(Comparator.naturalOrder()); + assertThat(actual).as("Size in bytes order should be ascending").isEqualTo(expected); + Collections.reverse(expected); + assertThat(actual).as("Size in bytes order should not be descending").isNotEqualTo(expected); + } + + @TestTemplate + public void testRewriteJobOrderBytesDesc() { + assumeThat(formatVersion).isGreaterThanOrEqualTo(2); + Table table = createTablePartitioned(4, 2); + writeRecords(1, SCALE, 1); + writeRecords(2, SCALE, 2); + writeRecords(3, SCALE, 3); + writeRecords(4, SCALE, 4); + + RewriteDataFilesSparkAction basicRewrite = basicRewrite(table).binPack(); + List expected = + toGroupStream(table, basicRewrite) + .mapToLong(RewriteFileGroup::sizeInBytes) + .boxed() + .collect(Collectors.toList()); + + RewriteDataFilesSparkAction jobOrderRewrite = + basicRewrite(table) + .option(RewriteDataFiles.REWRITE_JOB_ORDER, RewriteJobOrder.BYTES_DESC.orderName()) + .binPack(); + List actual = + toGroupStream(table, jobOrderRewrite) + .mapToLong(RewriteFileGroup::sizeInBytes) + .boxed() + .collect(Collectors.toList()); + + expected.sort(Comparator.reverseOrder()); + assertThat(actual).as("Size in bytes order should be descending").isEqualTo(expected); + Collections.reverse(expected); + assertThat(actual).as("Size in bytes order should not be ascending").isNotEqualTo(expected); + } + + @TestTemplate + public void testRewriteJobOrderFilesAsc() { + assumeThat(formatVersion).isGreaterThanOrEqualTo(2); + Table table = createTablePartitioned(4, 2); + writeRecords(1, SCALE, 1); + writeRecords(2, SCALE, 2); + writeRecords(3, SCALE, 3); + writeRecords(4, SCALE, 4); + + RewriteDataFilesSparkAction basicRewrite = basicRewrite(table).binPack(); + List expected = + toGroupStream(table, basicRewrite) + .mapToLong(RewriteFileGroup::numFiles) + .boxed() + .collect(Collectors.toList()); + + RewriteDataFilesSparkAction jobOrderRewrite = + basicRewrite(table) + .option(RewriteDataFiles.REWRITE_JOB_ORDER, RewriteJobOrder.FILES_ASC.orderName()) + .binPack(); + List actual = + toGroupStream(table, jobOrderRewrite) + .mapToLong(RewriteFileGroup::numFiles) + .boxed() + .collect(Collectors.toList()); + + expected.sort(Comparator.naturalOrder()); + assertThat(actual).as("Number of files order should be ascending").isEqualTo(expected); + Collections.reverse(expected); + assertThat(actual).as("Number of files order should not be descending").isNotEqualTo(expected); + } + + @TestTemplate + public void testRewriteJobOrderFilesDesc() { + assumeThat(formatVersion).isGreaterThanOrEqualTo(2); + Table table = createTablePartitioned(4, 2); + writeRecords(1, SCALE, 1); + writeRecords(2, SCALE, 2); + writeRecords(3, SCALE, 3); + writeRecords(4, SCALE, 4); + + RewriteDataFilesSparkAction basicRewrite = basicRewrite(table).binPack(); + List expected = + toGroupStream(table, basicRewrite) + 
.mapToLong(RewriteFileGroup::numFiles) + .boxed() + .collect(Collectors.toList()); + + RewriteDataFilesSparkAction jobOrderRewrite = + basicRewrite(table) + .option(RewriteDataFiles.REWRITE_JOB_ORDER, RewriteJobOrder.FILES_DESC.orderName()) + .binPack(); + List actual = + toGroupStream(table, jobOrderRewrite) + .mapToLong(RewriteFileGroup::numFiles) + .boxed() + .collect(Collectors.toList()); + + expected.sort(Comparator.reverseOrder()); + assertThat(actual).as("Number of files order should be descending").isEqualTo(expected); + Collections.reverse(expected); + assertThat(actual).as("Number of files order should not be ascending").isNotEqualTo(expected); + } + + @TestTemplate + public void testSnapshotProperty() { + Table table = createTable(4); + Result ignored = basicRewrite(table).snapshotProperty("key", "value").execute(); + assertThat(table.currentSnapshot().summary()) + .containsAllEntriesOf(ImmutableMap.of("key", "value")); + // make sure internal produced properties are not lost + String[] commitMetricsKeys = + new String[] { + SnapshotSummary.ADDED_FILES_PROP, + SnapshotSummary.DELETED_FILES_PROP, + SnapshotSummary.TOTAL_DATA_FILES_PROP, + SnapshotSummary.CHANGED_PARTITION_COUNT_PROP + }; + assertThat(table.currentSnapshot().summary()).containsKeys(commitMetricsKeys); + } + + @TestTemplate + public void testBinPackRewriterWithSpecificUnparitionedOutputSpec() { + Table table = createTable(10); + shouldHaveFiles(table, 10); + int outputSpecId = table.spec().specId(); + table.updateSpec().addField(Expressions.truncate("c2", 2)).commit(); + + long dataSizeBefore = testDataSize(table); + long count = currentData().size(); + + RewriteDataFiles.Result result = + basicRewrite(table) + .option(RewriteDataFiles.OUTPUT_SPEC_ID, String.valueOf(outputSpecId)) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .binPack() + .execute(); + + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + assertThat(currentData().size()).isEqualTo(count); + shouldRewriteDataFilesWithPartitionSpec(table, outputSpecId); + } + + @TestTemplate + public void testBinPackRewriterWithSpecificOutputSpec() { + Table table = createTable(10); + shouldHaveFiles(table, 10); + table.updateSpec().addField(Expressions.truncate("c2", 2)).commit(); + int outputSpecId = table.spec().specId(); + table.updateSpec().addField(Expressions.bucket("c3", 2)).commit(); + + long dataSizeBefore = testDataSize(table); + long count = currentData().size(); + + RewriteDataFiles.Result result = + basicRewrite(table) + .option(RewriteDataFiles.OUTPUT_SPEC_ID, String.valueOf(outputSpecId)) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .binPack() + .execute(); + + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + assertThat(currentData().size()).isEqualTo(count); + shouldRewriteDataFilesWithPartitionSpec(table, outputSpecId); + } + + @TestTemplate + public void testBinpackRewriteWithInvalidOutputSpecId() { + Table table = createTable(10); + shouldHaveFiles(table, 10); + assertThatThrownBy( + () -> + actions() + .rewriteDataFiles(table) + .option(RewriteDataFiles.OUTPUT_SPEC_ID, String.valueOf(1234)) + .binPack() + .execute()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Cannot use output spec id 1234 because the table does not contain a reference to this spec-id."); + } + + @TestTemplate + public void testSortRewriterWithSpecificOutputSpecId() { + Table table = createTable(10); + shouldHaveFiles(table, 10); + table.updateSpec().addField(Expressions.truncate("c2", 2)).commit(); 
+ int outputSpecId = table.spec().specId(); + table.updateSpec().addField(Expressions.bucket("c3", 2)).commit(); + + long dataSizeBefore = testDataSize(table); + long count = currentData().size(); + + RewriteDataFiles.Result result = + basicRewrite(table) + .option(RewriteDataFiles.OUTPUT_SPEC_ID, String.valueOf(outputSpecId)) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .sort(SortOrder.builderFor(table.schema()).asc("c2").asc("c3").build()) + .execute(); + + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + assertThat(currentData().size()).isEqualTo(count); + shouldRewriteDataFilesWithPartitionSpec(table, outputSpecId); + } + + @TestTemplate + public void testZOrderRewriteWithSpecificOutputSpecId() { + Table table = createTable(10); + shouldHaveFiles(table, 10); + table.updateSpec().addField(Expressions.truncate("c2", 2)).commit(); + int outputSpecId = table.spec().specId(); + table.updateSpec().addField(Expressions.bucket("c3", 2)).commit(); + + long dataSizeBefore = testDataSize(table); + long count = currentData().size(); + + RewriteDataFiles.Result result = + basicRewrite(table) + .option(RewriteDataFiles.OUTPUT_SPEC_ID, String.valueOf(outputSpecId)) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .zOrder("c2", "c3") + .execute(); + + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + assertThat(currentData().size()).isEqualTo(count); + shouldRewriteDataFilesWithPartitionSpec(table, outputSpecId); + } + + protected void shouldRewriteDataFilesWithPartitionSpec(Table table, int outputSpecId) { + List rewrittenFiles = currentDataFiles(table); + assertThat(rewrittenFiles).allMatch(file -> file.specId() == outputSpecId); + assertThat(rewrittenFiles) + .allMatch( + file -> + ((PartitionData) file.partition()) + .getPartitionType() + .equals(table.specs().get(outputSpecId).partitionType())); + } + + protected List currentDataFiles(Table table) { + return Streams.stream(table.newScan().planFiles()) + .map(FileScanTask::file) + .collect(Collectors.toList()); + } + + private Stream toGroupStream(Table table, RewriteDataFilesSparkAction rewrite) { + rewrite.validateAndInitOptions(); + StructLikeMap>> fileGroupsByPartition = + rewrite.planFileGroups(table.currentSnapshot().snapshotId()); + + return rewrite.toGroupStream( + new RewriteExecutionContext(fileGroupsByPartition), fileGroupsByPartition); + } + + protected List currentData() { + return rowsToJava( + spark.read().format("iceberg").load(tableLocation).sort("c1", "c2", "c3").collectAsList()); + } + + protected long testDataSize(Table table) { + return Streams.stream(table.newScan().planFiles()).mapToLong(FileScanTask::length).sum(); + } + + protected void shouldHaveMultipleFiles(Table table) { + table.refresh(); + int numFiles = Iterables.size(table.newScan().planFiles()); + assertThat(numFiles) + .as(String.format("Should have multiple files, had %d", numFiles)) + .isGreaterThan(1); + } + + protected void shouldHaveFiles(Table table, int numExpected) { + table.refresh(); + int numFiles = Iterables.size(table.newScan().planFiles()); + assertThat(numFiles).as("Did not have the expected number of files").isEqualTo(numExpected); + } + + protected long shouldHaveMinSequenceNumberInPartition( + Table table, String partitionFilter, long expected) { + long actual = + SparkTableUtil.loadMetadataTable(spark, table, MetadataTableType.ENTRIES) + .filter("status != 2") + .filter(partitionFilter) + .select("sequence_number") + .agg(min("sequence_number")) + .as(Encoders.LONG()) + .collectAsList() + 
.get(0); + assertThat(actual).as("Did not have the expected min sequence number").isEqualTo(expected); + return actual; + } + + protected void shouldHaveSnapshots(Table table, int expectedSnapshots) { + table.refresh(); + int actualSnapshots = Iterables.size(table.snapshots()); + assertThat(actualSnapshots) + .as("Table did not have the expected number of snapshots") + .isEqualTo(expectedSnapshots); + } + + protected void shouldHaveNoOrphans(Table table) { + assertThat( + actions() + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .execute() + .orphanFileLocations()) + .as("Should not have found any orphan files") + .isEmpty(); + } + + protected void shouldHaveOrphans(Table table) { + assertThat( + actions() + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .execute() + .orphanFileLocations()) + .as("Should have found orphan files") + .isNotEmpty(); + } + + protected void shouldHaveACleanCache(Table table) { + assertThat(cacheContents(table)).as("Should not have any entries in cache").isEmpty(); + } + + protected void shouldHaveLastCommitSorted(Table table, String column) { + List, Pair>> overlappingFiles = checkForOverlappingFiles(table, column); + + assertThat(overlappingFiles).as("Found overlapping files").isEmpty(); + } + + protected void shouldHaveLastCommitUnsorted(Table table, String column) { + List, Pair>> overlappingFiles = checkForOverlappingFiles(table, column); + + assertThat(overlappingFiles).as("Found no overlapping files").isNotEmpty(); + } + + private Pair boundsOf(DataFile file, NestedField field, Class javaClass) { + int columnId = field.fieldId(); + return Pair.of( + javaClass.cast(Conversions.fromByteBuffer(field.type(), file.lowerBounds().get(columnId))), + javaClass.cast(Conversions.fromByteBuffer(field.type(), file.upperBounds().get(columnId)))); + } + + private List, Pair>> checkForOverlappingFiles( + Table table, String column) { + table.refresh(); + NestedField field = table.schema().caseInsensitiveFindField(column); + Class javaClass = (Class) field.type().typeId().javaClass(); + + Snapshot snapshot = table.currentSnapshot(); + Map> filesByPartition = + Streams.stream(snapshot.addedDataFiles(table.io())) + .collect(Collectors.groupingBy(DataFile::partition)); + + Stream, Pair>> overlaps = + filesByPartition.entrySet().stream() + .flatMap( + entry -> { + List datafiles = entry.getValue(); + Preconditions.checkArgument( + datafiles.size() > 1, + "This test is checking for overlaps in a situation where no overlaps can actually occur because the " + + "partition %s does not contain multiple datafiles", + entry.getKey()); + + List, Pair>> boundComparisons = + Lists.cartesianProduct(datafiles, datafiles).stream() + .filter(tuple -> tuple.get(0) != tuple.get(1)) + .map( + tuple -> + Pair.of( + boundsOf(tuple.get(0), field, javaClass), + boundsOf(tuple.get(1), field, javaClass))) + .collect(Collectors.toList()); + + Comparator comparator = Comparators.forType(field.type().asPrimitiveType()); + + List, Pair>> overlappingFiles = + boundComparisons.stream() + .filter( + filePair -> { + Pair left = filePair.first(); + T lMin = left.first(); + T lMax = left.second(); + Pair right = filePair.second(); + T rMin = right.first(); + T rMax = right.second(); + boolean boundsDoNotOverlap = + // Min and Max of a range are greater than or equal to the max + // value of the other range + (comparator.compare(rMax, lMax) >= 0 + && comparator.compare(rMin, lMax) >= 0) + || (comparator.compare(lMax, rMax) >= 0 + && comparator.compare(lMin, rMax) 
>= 0); + + return !boundsDoNotOverlap; + }) + .collect(Collectors.toList()); + return overlappingFiles.stream(); + }); + + return overlaps.collect(Collectors.toList()); + } + + protected Table createTable() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = + ImmutableMap.of(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + table + .updateProperties() + .set(TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, Integer.toString(20 * 1024)) + .commit(); + assertThat(table.currentSnapshot()).as("Table must be empty").isNull(); + return table; + } + + /** + * Create a table with a certain number of files and return the created table. + * + * @param files number of files to create + * @return the created table + */ + protected Table createTable(int files) { + Table table = createTable(); + writeRecords(files, SCALE); + return table; + } + + protected Table createTablePartitioned( + int partitions, int files, int numRecords, Map options) { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c1").truncate("c2", 2).build(); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + assertThat(table.currentSnapshot()).as("Table must be empty").isNull(); + + writeRecords(files, numRecords, partitions); + return table; + } + + protected Table createTablePartitioned(int partitions, int files) { + return createTablePartitioned( + partitions, + files, + SCALE, + ImmutableMap.of(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion))); + } + + protected Table createTablePartitioned(int partitions, int files, int numRecords) { + return createTablePartitioned( + partitions, + files, + numRecords, + ImmutableMap.of(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion))); + } + + private Table createTypeTestTable() { + Schema schema = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "dateCol", Types.DateType.get()), + optional(6, "timestampCol", Types.TimestampType.withZone()), + optional(7, "stringCol", Types.StringType.get()), + optional(8, "booleanCol", Types.BooleanType.get()), + optional(9, "binaryCol", Types.BinaryType.get())); + + Map options = + ImmutableMap.of(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + Table table = TABLES.create(schema, PartitionSpec.unpartitioned(), options, tableLocation); + + spark + .range(0, 10, 1, 10) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("dateCol", date_add(current_date(), 1)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")) + .withColumn("booleanCol", expr("longCol > 5")) + .withColumn("binaryCol", expr("CAST(longCol AS BINARY)")) + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + return table; + } + + protected int averageFileSize(Table table) { + table.refresh(); + return (int) + Streams.stream(table.newScan().planFiles()) + .mapToLong(FileScanTask::length) + .average() + .getAsDouble(); + } + + private void writeRecords(List records) { + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + writeDF(df); + } + + private void writeRecords(int
files, int numRecords) { + writeRecords(files, numRecords, 0); + } + + private void writeRecords(int files, int numRecords, int partitions) { + List records = Lists.newArrayList(); + int rowDimension = (int) Math.ceil(Math.sqrt(numRecords)); + List> data = + IntStream.range(0, rowDimension) + .boxed() + .flatMap(x -> IntStream.range(0, rowDimension).boxed().map(y -> Pair.of(x, y))) + .collect(Collectors.toList()); + Collections.shuffle(data, new Random(42)); + if (partitions > 0) { + data.forEach( + i -> + records.add( + new ThreeColumnRecord( + i.first() % partitions, "foo" + i.first(), "bar" + i.second()))); + } else { + data.forEach( + i -> + records.add(new ThreeColumnRecord(i.first(), "foo" + i.first(), "bar" + i.second()))); + } + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).repartition(files); + writeDF(df); + } + + private void writeDF(Dataset df) { + df.select("c1", "c2", "c3") + .sortWithinPartitions("c1", "c2") + .write() + .format("iceberg") + .mode("append") + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .save(tableLocation); + } + + private List writePosDeletesToFile( + Table table, DataFile dataFile, int outputDeleteFiles) { + return writePosDeletes( + table, dataFile.partition(), dataFile.path().toString(), outputDeleteFiles); + } + + private List writePosDeletes( + Table table, StructLike partition, String path, int outputDeleteFiles) { + List results = Lists.newArrayList(); + int rowPosition = 0; + for (int file = 0; file < outputDeleteFiles; file++) { + OutputFile outputFile = + table + .io() + .newOutputFile( + table + .locationProvider() + .newDataLocation( + FileFormat.PARQUET.addExtension(UUID.randomUUID().toString()))); + EncryptedOutputFile encryptedOutputFile = + EncryptedFiles.encryptedOutput(outputFile, EncryptionKeyMetadata.EMPTY); + + GenericAppenderFactory appenderFactory = + new GenericAppenderFactory(table.schema(), table.spec(), null, null, null); + PositionDeleteWriter posDeleteWriter = + appenderFactory + .set(TableProperties.DEFAULT_WRITE_METRICS_MODE, "full") + .newPosDeleteWriter(encryptedOutputFile, FileFormat.PARQUET, partition); + + PositionDelete posDelete = PositionDelete.create(); + posDeleteWriter.write(posDelete.set(path, rowPosition, null)); + try { + posDeleteWriter.close(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + results.add(posDeleteWriter.toDeleteFile()); + rowPosition++; + } + + return results; + } + + private List writeDV( + Table table, StructLike partition, String path, int numPositionsToDelete) throws IOException { + OutputFileFactory fileFactory = + OutputFileFactory.builderFor(table, 1, 1).format(FileFormat.PUFFIN).build(); + DVFileWriter writer = new BaseDVFileWriter(fileFactory, p -> null); + try (DVFileWriter closeableWriter = writer) { + for (int row = 0; row < numPositionsToDelete; row++) { + closeableWriter.delete(path, row, table.spec(), partition); + } + } + + return writer.result().deleteFiles(); + } + + private void writeEqDeleteRecord( + Table table, String partCol, Object partVal, String delCol, Object delVal) { + List equalityFieldIds = Lists.newArrayList(table.schema().findField(delCol).fieldId()); + Schema eqDeleteRowSchema = table.schema().select(delCol); + Record partitionRecord = + GenericRecord.create(table.schema().select(partCol)) + .copy(ImmutableMap.of(partCol, partVal)); + Record record = GenericRecord.create(eqDeleteRowSchema).copy(ImmutableMap.of(delCol, delVal)); + writeEqDeleteRecord(table, equalityFieldIds, 
partitionRecord, eqDeleteRowSchema, record); + } + + private void writeEqDeleteRecord( + Table table, + List equalityFieldIds, + Record partitionRecord, + Schema eqDeleteRowSchema, + Record deleteRecord) { + OutputFileFactory fileFactory = + OutputFileFactory.builderFor(table, 1, 1).format(FileFormat.PARQUET).build(); + GenericAppenderFactory appenderFactory = + new GenericAppenderFactory( + table.schema(), + table.spec(), + ArrayUtil.toIntArray(equalityFieldIds), + eqDeleteRowSchema, + null); + + EncryptedOutputFile file = + createEncryptedOutputFile(createPartitionKey(table, partitionRecord), fileFactory); + + EqualityDeleteWriter eqDeleteWriter = + appenderFactory.newEqDeleteWriter( + file, FileFormat.PARQUET, createPartitionKey(table, partitionRecord)); + + try (EqualityDeleteWriter clsEqDeleteWriter = eqDeleteWriter) { + clsEqDeleteWriter.write(deleteRecord); + } catch (Exception e) { + throw new RuntimeException(e); + } + table.newRowDelta().addDeletes(eqDeleteWriter.toDeleteFile()).commit(); + } + + private PartitionKey createPartitionKey(Table table, Record record) { + if (table.spec().isUnpartitioned()) { + return null; + } + + PartitionKey partitionKey = new PartitionKey(table.spec(), table.schema()); + partitionKey.partition(record); + + return partitionKey; + } + + private EncryptedOutputFile createEncryptedOutputFile( + PartitionKey partition, OutputFileFactory fileFactory) { + if (partition == null) { + return fileFactory.newOutputFile(); + } else { + return fileFactory.newOutputFile(partition); + } + } + + private SparkActions actions() { + return SparkActions.get(); + } + + private Set cacheContents(Table table) { + return ImmutableSet.builder() + .addAll(manager.fetchSetIds(table)) + .addAll(coordinator.fetchSetIds(table)) + .build(); + } + + private double percentFilesRequired(Table table, String col, String value) { + return percentFilesRequired(table, new String[] {col}, new String[] {value}); + } + + private double percentFilesRequired(Table table, String[] cols, String[] values) { + Preconditions.checkArgument(cols.length == values.length); + Expression restriction = Expressions.alwaysTrue(); + for (int i = 0; i < cols.length; i++) { + restriction = Expressions.and(restriction, Expressions.equal(cols[i], values[i])); + } + int totalFiles = Iterables.size(table.newScan().planFiles()); + int filteredFiles = Iterables.size(table.newScan().filter(restriction).planFiles()); + return (double) filteredFiles / (double) totalFiles; + } + + class GroupInfoMatcher implements ArgumentMatcher { + private final Set groupIDs; + + GroupInfoMatcher(Integer... globalIndex) { + this.groupIDs = ImmutableSet.copyOf(globalIndex); + } + + @Override + public boolean matches(RewriteFileGroup argument) { + return groupIDs.contains(argument.info().globalIndex()); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteManifestsAction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteManifestsAction.java new file mode 100644 index 000000000000..44971843547b --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteManifestsAction.java @@ -0,0 +1,1202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.ValidationHelpers.dataSeqs; +import static org.apache.iceberg.ValidationHelpers.fileSeqs; +import static org.apache.iceberg.ValidationHelpers.files; +import static org.apache.iceberg.ValidationHelpers.snapshotIds; +import static org.apache.iceberg.ValidationHelpers.validateDataManifest; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileGenerationUtil; +import org.apache.iceberg.FileMetadata; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Files; +import org.apache.iceberg.ManifestContent; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestWriter; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.actions.RewriteManifests; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.CharSequenceSet; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.Dataset; +import 
org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestRewriteManifestsAction extends TestBase { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + + @Parameters( + name = + "snapshotIdInheritanceEnabled = {0}, useCaching = {1}, shouldStageManifests = {2}, formatVersion = {3}") + public static Object[] parameters() { + return new Object[][] { + new Object[] {"true", "true", false, 1}, + new Object[] {"false", "true", true, 1}, + new Object[] {"true", "false", false, 2}, + new Object[] {"false", "false", false, 2}, + new Object[] {"true", "false", false, 3}, + new Object[] {"false", "false", false, 3} + }; + } + + @Parameter private String snapshotIdInheritanceEnabled; + + @Parameter(index = 1) + private String useCaching; + + @Parameter(index = 2) + private boolean shouldStageManifests; + + @Parameter(index = 3) + private int formatVersion; + + private String tableLocation = null; + + @TempDir private Path temp; + @TempDir private File tableDir; + + @BeforeEach + public void setupTableLocation() throws Exception { + this.tableLocation = tableDir.toURI().toString(); + } + + @TestTemplate + public void testRewriteManifestsPreservesOptionalFields() throws IOException { + assumeThat(formatVersion).isGreaterThanOrEqualTo(2); + + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c1").build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + DataFile dataFile1 = newDataFile(table, "c1=0"); + DataFile dataFile2 = newDataFile(table, "c1=0"); + DataFile dataFile3 = newDataFile(table, "c1=0"); + table + .newFastAppend() + .appendFile(dataFile1) + .appendFile(dataFile2) + .appendFile(dataFile3) + .commit(); + + DeleteFile deleteFile1 = newDeletes(table, dataFile1); + assertDeletes(dataFile1, deleteFile1); + table.newRowDelta().addDeletes(deleteFile1).commit(); + + DeleteFile deleteFile2 = newDeletes(table, dataFile2); + assertDeletes(dataFile2, deleteFile2); + table.newRowDelta().addDeletes(deleteFile2).commit(); + + DeleteFile deleteFile3 = newDeletes(table, dataFile3); + assertDeletes(dataFile3, deleteFile3); + table.newRowDelta().addDeletes(deleteFile3).commit(); + + SparkActions actions = SparkActions.get(); + + actions + .rewriteManifests(table) + .rewriteIf(manifest -> true) + .option(RewriteManifestsSparkAction.USE_CACHING, useCaching) + .execute(); + + table.refresh(); + + try (CloseableIterable tasks = table.newScan().planFiles()) { + for (FileScanTask fileTask : tasks) { + DataFile dataFile = fileTask.file(); + DeleteFile deleteFile = Iterables.getOnlyElement(fileTask.deletes()); + if (dataFile.location().equals(dataFile1.location())) { + assertThat(deleteFile.referencedDataFile()).isEqualTo(deleteFile1.referencedDataFile()); + assertEqual(deleteFile, deleteFile1); + } else if (dataFile.location().equals(dataFile2.location())) { + 
assertThat(deleteFile.referencedDataFile()).isEqualTo(deleteFile2.referencedDataFile()); + assertEqual(deleteFile, deleteFile2); + } else { + assertThat(deleteFile.referencedDataFile()).isEqualTo(deleteFile3.referencedDataFile()); + assertEqual(deleteFile, deleteFile3); + } + } + } + } + + @TestTemplate + public void testRewriteManifestsEmptyTable() throws IOException { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + assertThat(table.currentSnapshot()).as("Table must be empty").isNull(); + + SparkActions actions = SparkActions.get(); + + actions + .rewriteManifests(table) + .rewriteIf(manifest -> true) + .option(RewriteManifestsSparkAction.USE_CACHING, useCaching) + .stagingLocation(java.nio.file.Files.createTempDirectory(temp, "junit").toString()) + .execute(); + + assertThat(table.currentSnapshot()).as("Table must stay empty").isNull(); + } + + @TestTemplate + public void testRewriteSmallManifestsNonPartitionedTable() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB")); + writeRecords(records1); + + List records2 = + Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD")); + writeRecords(records2); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + assertThat(manifests).as("Should have 2 manifests before rewrite").hasSize(2); + + SparkActions actions = SparkActions.get(); + + RewriteManifests.Result result = + actions + .rewriteManifests(table) + .rewriteIf(manifest -> true) + .option(RewriteManifestsSparkAction.USE_CACHING, useCaching) + .execute(); + + assertThat(result.rewrittenManifests()).as("Action should rewrite 2 manifests").hasSize(2); + assertThat(result.addedManifests()).as("Action should add 1 manifests").hasSize(1); + assertManifestsLocation(result.addedManifests()); + + table.refresh(); + + List newManifests = table.currentSnapshot().allManifests(table.io()); + assertThat(newManifests).as("Should have 1 manifests after rewrite").hasSize(1); + + assertThat(newManifests.get(0).existingFilesCount()).isEqualTo(4); + assertThat(newManifests.get(0).hasAddedFiles()).isFalse(); + assertThat(newManifests.get(0).hasDeletedFiles()).isFalse(); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records1); + expectedRecords.addAll(records2); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.sort("c1", "c2").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + + assertThat(actualRecords).as("Rows must match").isEqualTo(expectedRecords); + } + + @TestTemplate + public void testRewriteManifestsWithCommitStateUnknownException() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + 
options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB")); + writeRecords(records1); + + List records2 = + Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD")); + writeRecords(records2); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + assertThat(manifests).as("Should have 2 manifests before rewrite").hasSize(2); + + SparkActions actions = SparkActions.get(); + + // create a spy which would throw a CommitStateUnknownException after successful commit. + org.apache.iceberg.RewriteManifests newRewriteManifests = table.rewriteManifests(); + org.apache.iceberg.RewriteManifests spyNewRewriteManifests = spy(newRewriteManifests); + doAnswer( + invocation -> { + newRewriteManifests.commit(); + throw new CommitStateUnknownException(new RuntimeException("Datacenter on Fire")); + }) + .when(spyNewRewriteManifests) + .commit(); + + Table spyTable = spy(table); + when(spyTable.rewriteManifests()).thenReturn(spyNewRewriteManifests); + + assertThatThrownBy( + () -> actions.rewriteManifests(spyTable).rewriteIf(manifest -> true).execute()) + .cause() + .isInstanceOf(RuntimeException.class) + .hasMessage("Datacenter on Fire"); + + table.refresh(); + + // table should reflect the changes, since the commit was successful + List newManifests = table.currentSnapshot().allManifests(table.io()); + assertThat(newManifests).as("Should have 1 manifests after rewrite").hasSize(1); + + assertThat(newManifests.get(0).existingFilesCount()).isEqualTo(4); + assertThat(newManifests.get(0).hasAddedFiles()).isFalse(); + assertThat(newManifests.get(0).hasDeletedFiles()).isFalse(); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records1); + expectedRecords.addAll(records2); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.sort("c1", "c2").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + + assertThat(actualRecords).as("Rows must match").isEqualTo(expectedRecords); + } + + @TestTemplate + public void testRewriteSmallManifestsPartitionedTable() { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c1").truncate("c2", 2).build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB")); + writeRecords(records1); + + List records2 = + Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD")); + writeRecords(records2); + + List records3 = + Lists.newArrayList( + new ThreeColumnRecord(3, "EEEEEEEEEE", "EEEE"), + new ThreeColumnRecord(3, "FFFFFFFFFF", "FFFF")); + writeRecords(records3); + + List records4 = + Lists.newArrayList( + new ThreeColumnRecord(4, "GGGGGGGGGG", "GGGG"), + new ThreeColumnRecord(4, "HHHHHHHHHG", "HHHH")); + writeRecords(records4); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + assertThat(manifests).as("Should 
have 4 manifests before rewrite").hasSize(4); + + SparkActions actions = SparkActions.get(); + + // we will expect to have 2 manifests with 4 entries in each after rewrite + long manifestEntrySizeBytes = computeManifestEntrySizeBytes(manifests); + long targetManifestSizeBytes = (long) (1.05 * 4 * manifestEntrySizeBytes); + + table + .updateProperties() + .set(TableProperties.MANIFEST_TARGET_SIZE_BYTES, String.valueOf(targetManifestSizeBytes)) + .commit(); + + RewriteManifests.Result result = + actions + .rewriteManifests(table) + .rewriteIf(manifest -> true) + .option(RewriteManifestsSparkAction.USE_CACHING, useCaching) + .execute(); + + assertThat(result.rewrittenManifests()).as("Action should rewrite 4 manifests").hasSize(4); + assertThat(result.addedManifests()).as("Action should add 2 manifests").hasSize(2); + assertManifestsLocation(result.addedManifests()); + + table.refresh(); + + List newManifests = table.currentSnapshot().allManifests(table.io()); + + assertThat(newManifests).as("Should have 2 manifests after rewrite").hasSize(2); + + assertThat(newManifests.get(0).existingFilesCount()).isEqualTo(4); + assertThat(newManifests.get(0).hasAddedFiles()).isFalse(); + assertThat(newManifests.get(0).hasDeletedFiles()).isFalse(); + + assertThat(newManifests.get(1).existingFilesCount()).isEqualTo(4); + assertThat(newManifests.get(1).hasAddedFiles()).isFalse(); + assertThat(newManifests.get(1).hasDeletedFiles()).isFalse(); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records1); + expectedRecords.addAll(records2); + expectedRecords.addAll(records3); + expectedRecords.addAll(records4); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.sort("c1", "c2").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + + assertThat(actualRecords).as("Rows must match").isEqualTo(expectedRecords); + } + + @TestTemplate + public void testRewriteImportedManifests() throws IOException { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c3").build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB")); + File parquetTableDir = temp.resolve("parquet_table").toFile(); + String parquetTableLocation = parquetTableDir.toURI().toString(); + + try { + Dataset inputDF = spark.createDataFrame(records, ThreeColumnRecord.class); + inputDF + .select("c1", "c2", "c3") + .write() + .format("parquet") + .mode("overwrite") + .option("path", parquetTableLocation) + .partitionBy("c3") + .saveAsTable("parquet_table"); + + File stagingDir = temp.resolve("staging-dir").toFile(); + SparkTableUtil.importSparkTable( + spark, new TableIdentifier("parquet_table"), table, stagingDir.toString()); + + // add some more data to create more than one manifest for the rewrite + inputDF.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + table.refresh(); + + Snapshot snapshot = table.currentSnapshot(); + + SparkActions actions = SparkActions.get(); + + String rewriteStagingLocation = + java.nio.file.Files.createTempDirectory(temp, "junit").toString(); + + RewriteManifests.Result result = + actions + .rewriteManifests(table) + .rewriteIf(manifest -> 
true) + .option(RewriteManifestsSparkAction.USE_CACHING, useCaching) + .stagingLocation(rewriteStagingLocation) + .execute(); + + assertThat(result.rewrittenManifests()) + .as("Action should rewrite all manifests") + .isEqualTo(snapshot.allManifests(table.io())); + assertThat(result.addedManifests()).as("Action should add 1 manifest").hasSize(1); + assertManifestsLocation(result.addedManifests(), rewriteStagingLocation); + + } finally { + spark.sql("DROP TABLE parquet_table"); + } + } + + @TestTemplate + public void testRewriteLargeManifestsPartitionedTable() throws IOException { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c3").build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List dataFiles = Lists.newArrayList(); + for (int fileOrdinal = 0; fileOrdinal < 1000; fileOrdinal++) { + dataFiles.add(newDataFile(table, "c3=" + fileOrdinal)); + } + ManifestFile appendManifest = writeManifest(table, dataFiles); + table.newFastAppend().appendManifest(appendManifest).commit(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + assertThat(manifests).as("Should have 1 manifests before rewrite").hasSize(1); + + // set the target manifest size to a small value to force splitting records into multiple files + table + .updateProperties() + .set( + TableProperties.MANIFEST_TARGET_SIZE_BYTES, + String.valueOf(manifests.get(0).length() / 2)) + .commit(); + + SparkActions actions = SparkActions.get(); + + String stagingLocation = java.nio.file.Files.createTempDirectory(temp, "junit").toString(); + + RewriteManifests.Result result = + actions + .rewriteManifests(table) + .rewriteIf(manifest -> true) + .option(RewriteManifestsSparkAction.USE_CACHING, useCaching) + .stagingLocation(stagingLocation) + .execute(); + + assertThat(result.rewrittenManifests()).hasSize(1); + assertThat(result.addedManifests()).hasSizeGreaterThanOrEqualTo(2); + assertManifestsLocation(result.addedManifests(), stagingLocation); + + table.refresh(); + + List newManifests = table.currentSnapshot().allManifests(table.io()); + assertThat(newManifests).hasSizeGreaterThanOrEqualTo(2); + } + + @TestTemplate + public void testRewriteManifestsWithPredicate() throws IOException { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c1").truncate("c2", 2).build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB")); + writeRecords(records1); + + writeRecords(records1); + + List records2 = + Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD")); + writeRecords(records2); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + assertThat(manifests).as("Should have 3 manifests before rewrite").hasSize(3); + + SparkActions actions = SparkActions.get(); + + String stagingLocation = java.nio.file.Files.createTempDirectory(temp, "junit").toString(); + + // rewrite only the first two manifests + RewriteManifests.Result
result = + actions + .rewriteManifests(table) + .rewriteIf( + manifest -> + (manifest.path().equals(manifests.get(0).path()) + || (manifest.path().equals(manifests.get(1).path())))) + .stagingLocation(stagingLocation) + .option(RewriteManifestsSparkAction.USE_CACHING, useCaching) + .execute(); + + assertThat(result.rewrittenManifests()).as("Action should rewrite 2 manifest").hasSize(2); + assertThat(result.addedManifests()).as("Action should add 1 manifests").hasSize(1); + assertManifestsLocation(result.addedManifests(), stagingLocation); + + table.refresh(); + + List newManifests = table.currentSnapshot().allManifests(table.io()); + assertThat(newManifests) + .as("Should have 2 manifests after rewrite") + .hasSize(2) + .as("First manifest must be rewritten") + .doesNotContain(manifests.get(0)) + .as("Second manifest must be rewritten") + .doesNotContain(manifests.get(1)) + .as("Third manifest must not be rewritten") + .contains(manifests.get(2)); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.add(records1.get(0)); + expectedRecords.add(records1.get(0)); + expectedRecords.add(records1.get(1)); + expectedRecords.add(records1.get(1)); + expectedRecords.addAll(records2); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.sort("c1", "c2").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + + assertThat(actualRecords).as("Rows must match").isEqualTo(expectedRecords); + } + + @TestTemplate + public void testRewriteSmallManifestsNonPartitionedV2Table() { + assumeThat(formatVersion).isGreaterThan(1); + + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = ImmutableMap.of(TableProperties.FORMAT_VERSION, "2"); + Table table = TABLES.create(SCHEMA, spec, properties, tableLocation); + + List records1 = Lists.newArrayList(new ThreeColumnRecord(1, null, "AAAA")); + writeRecords(records1); + + table.refresh(); + + Snapshot snapshot1 = table.currentSnapshot(); + DataFile file1 = Iterables.getOnlyElement(snapshot1.addedDataFiles(table.io())); + + List records2 = Lists.newArrayList(new ThreeColumnRecord(2, "CCCC", "CCCC")); + writeRecords(records2); + + table.refresh(); + + Snapshot snapshot2 = table.currentSnapshot(); + DataFile file2 = Iterables.getOnlyElement(snapshot2.addedDataFiles(table.io())); + + List manifests = table.currentSnapshot().allManifests(table.io()); + assertThat(manifests).as("Should have 2 manifests before rewrite").hasSize(2); + + SparkActions actions = SparkActions.get(); + RewriteManifests.Result result = + actions + .rewriteManifests(table) + .option(RewriteManifestsSparkAction.USE_CACHING, useCaching) + .execute(); + assertThat(result.rewrittenManifests()).as("Action should rewrite 2 manifests").hasSize(2); + assertThat(result.addedManifests()).as("Action should add 1 manifests").hasSize(1); + assertManifestsLocation(result.addedManifests()); + + table.refresh(); + + List newManifests = table.currentSnapshot().allManifests(table.io()); + assertThat(newManifests).as("Should have 1 manifests after rewrite").hasSize(1); + + ManifestFile newManifest = Iterables.getOnlyElement(newManifests); + assertThat(newManifest.existingFilesCount()).isEqualTo(2); + assertThat(newManifest.hasAddedFiles()).isFalse(); + assertThat(newManifest.hasDeletedFiles()).isFalse(); + + validateDataManifest( + table, + newManifest, + dataSeqs(1L, 2L), + fileSeqs(1L, 2L), + snapshotIds(snapshot1.snapshotId(), snapshot2.snapshotId()), + files(file1, file2)); + + List expectedRecords = Lists.newArrayList(); + 
expectedRecords.addAll(records1); + expectedRecords.addAll(records2); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.sort("c1", "c2").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + + assertThat(actualRecords).as("Rows must match").isEqualTo(expectedRecords); + } + + @TestTemplate + public void testRewriteLargeManifestsEvolvedUnpartitionedV1Table() throws IOException { + assumeThat(formatVersion).isEqualTo(1); + + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c3").build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + options.put(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + table.updateSpec().removeField("c3").commit(); + + assertThat(table.spec().fields()).hasSize(1).allMatch(field -> field.transform().isVoid()); + + List dataFiles = Lists.newArrayList(); + for (int fileOrdinal = 0; fileOrdinal < 1000; fileOrdinal++) { + dataFiles.add(newDataFile(table, TestHelpers.Row.of(new Object[] {null}))); + } + ManifestFile appendManifest = writeManifest(table, dataFiles); + table.newFastAppend().appendManifest(appendManifest).commit(); + + List originalManifests = table.currentSnapshot().allManifests(table.io()); + ManifestFile originalManifest = Iterables.getOnlyElement(originalManifests); + + // set the target manifest size to a small value to force splitting records into multiple files + table + .updateProperties() + .set( + TableProperties.MANIFEST_TARGET_SIZE_BYTES, + String.valueOf(originalManifest.length() / 2)) + .commit(); + + SparkActions actions = SparkActions.get(); + + RewriteManifests.Result result = + actions + .rewriteManifests(table) + .rewriteIf(manifest -> true) + .option(RewriteManifestsSparkAction.USE_CACHING, useCaching) + .execute(); + + assertThat(result.rewrittenManifests()).hasSize(1); + assertThat(result.addedManifests()).hasSizeGreaterThanOrEqualTo(2); + assertManifestsLocation(result.addedManifests()); + + List manifests = table.currentSnapshot().allManifests(table.io()); + assertThat(manifests).hasSizeGreaterThanOrEqualTo(2); + } + + @TestTemplate + public void testRewriteSmallDeleteManifestsNonPartitionedTable() throws IOException { + assumeThat(formatVersion).isEqualTo(2); + + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + // commit data records + List records = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "BBBB"), + new ThreeColumnRecord(3, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(4, "DDDDDDDDDD", "DDDD")); + writeRecords(records); + + // commit a position delete file to remove records where c1 = 1 OR c1 = 2 + List> posDeletes = generatePosDeletes("c1 = 1 OR c1 = 2"); + Pair posDeleteWriteResult = writePosDeletes(table, posDeletes); + table + .newRowDelta() + .addDeletes(posDeleteWriteResult.first()) + .validateDataFilesExist(posDeleteWriteResult.second()) + .commit(); + + // commit an equality delete file to remove all records where c1 = 3 + DeleteFile eqDeleteFile = writeEqDeletes(table, "c1", 3); + table.newRowDelta().addDeletes(eqDeleteFile).commit(); 
+ + // the current snapshot should contain 1 data manifest and 2 delete manifests + List originalManifests = table.currentSnapshot().allManifests(table.io()); + assertThat(originalManifests).hasSize(3); + + SparkActions actions = SparkActions.get(); + + RewriteManifests.Result result = + actions + .rewriteManifests(table) + .option(RewriteManifestsSparkAction.USE_CACHING, useCaching) + .execute(); + + // the original delete manifests must be combined + assertThat(result.rewrittenManifests()) + .hasSize(2) + .allMatch(m -> m.content() == ManifestContent.DELETES); + assertThat(result.addedManifests()) + .hasSize(1) + .allMatch(m -> m.content() == ManifestContent.DELETES); + assertManifestsLocation(result.addedManifests()); + + // the new delete manifest must only contain files with status EXISTING + ManifestFile deleteManifest = + Iterables.getOnlyElement(table.currentSnapshot().deleteManifests(table.io())); + assertThat(deleteManifest.existingFilesCount()).isEqualTo(2); + assertThat(deleteManifest.hasAddedFiles()).isFalse(); + assertThat(deleteManifest.hasDeletedFiles()).isFalse(); + + // the preserved data manifest must only contain files with status ADDED + ManifestFile dataManifest = + Iterables.getOnlyElement(table.currentSnapshot().dataManifests(table.io())); + assertThat(dataManifest.hasExistingFiles()).isFalse(); + assertThat(dataManifest.hasAddedFiles()).isTrue(); + assertThat(dataManifest.hasDeletedFiles()).isFalse(); + + // the table must produce expected records after the rewrite + List expectedRecords = + Lists.newArrayList(new ThreeColumnRecord(4, "DDDDDDDDDD", "DDDD")); + assertThat(actualRecords()).isEqualTo(expectedRecords); + } + + @TestTemplate + public void testRewriteSmallDeleteManifestsPartitionedTable() throws IOException { + assumeThat(formatVersion).isEqualTo(2); + + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c3").build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + options.put(TableProperties.MANIFEST_MERGE_ENABLED, "false"); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + // commit data records + List records = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "BBBB"), + new ThreeColumnRecord(3, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(4, "DDDDDDDDDD", "DDDD"), + new ThreeColumnRecord(5, "EEEEEEEEEE", "EEEE")); + writeRecords(records); + + // commit the first position delete file to remove records where c1 = 1 + List> posDeletes1 = generatePosDeletes("c1 = 1"); + Pair posDeleteWriteResult1 = + writePosDeletes(table, TestHelpers.Row.of("AAAA"), posDeletes1); + table + .newRowDelta() + .addDeletes(posDeleteWriteResult1.first()) + .validateDataFilesExist(posDeleteWriteResult1.second()) + .commit(); + + // commit the second position delete file to remove records where c1 = 2 + List> posDeletes2 = generatePosDeletes("c1 = 2"); + Pair positionDeleteWriteResult2 = + writePosDeletes(table, TestHelpers.Row.of("BBBB"), posDeletes2); + table + .newRowDelta() + .addDeletes(positionDeleteWriteResult2.first()) + .validateDataFilesExist(positionDeleteWriteResult2.second()) + .commit(); + + // commit the first equality delete file to remove records where c1 = 3 + DeleteFile eqDeleteFile1 = writeEqDeletes(table, TestHelpers.Row.of("CCCC"), "c1", 3); + table.newRowDelta().addDeletes(eqDeleteFile1).commit(); + + // 
commit the second equality delete file to remove records where c1 = 4 + DeleteFile eqDeleteFile2 = writeEqDeletes(table, TestHelpers.Row.of("DDDD"), "c1", 4); + table.newRowDelta().addDeletes(eqDeleteFile2).commit(); + + // the table must have 1 data manifest and 4 delete manifests + List originalManifests = table.currentSnapshot().allManifests(table.io()); + assertThat(originalManifests).hasSize(5); + + // set the target manifest size to have 2 manifests with 2 entries in each after the rewrite + List originalDeleteManifests = + table.currentSnapshot().deleteManifests(table.io()); + long manifestEntrySizeBytes = computeManifestEntrySizeBytes(originalDeleteManifests); + long targetManifestSizeBytes = (long) (1.05 * 2 * manifestEntrySizeBytes); + + table + .updateProperties() + .set(TableProperties.MANIFEST_TARGET_SIZE_BYTES, String.valueOf(targetManifestSizeBytes)) + .commit(); + + SparkActions actions = SparkActions.get(); + + RewriteManifests.Result result = + actions + .rewriteManifests(table) + .rewriteIf(manifest -> manifest.content() == ManifestContent.DELETES) + .option(RewriteManifestsSparkAction.USE_CACHING, useCaching) + .execute(); + + // the original 4 delete manifests must be replaced with 2 new delete manifests + assertThat(result.rewrittenManifests()) + .hasSize(4) + .allMatch(m -> m.content() == ManifestContent.DELETES); + assertThat(result.addedManifests()) + .hasSize(2) + .allMatch(m -> m.content() == ManifestContent.DELETES); + assertManifestsLocation(result.addedManifests()); + + List deleteManifests = table.currentSnapshot().deleteManifests(table.io()); + assertThat(deleteManifests).hasSize(2); + + // the first new delete manifest must only contain files with status EXISTING + ManifestFile deleteManifest1 = deleteManifests.get(0); + assertThat(deleteManifest1.existingFilesCount()).isEqualTo(2); + assertThat(deleteManifest1.hasAddedFiles()).isFalse(); + assertThat(deleteManifest1.hasDeletedFiles()).isFalse(); + + // the second new delete manifest must only contain files with status EXISTING + ManifestFile deleteManifest2 = deleteManifests.get(1); + assertThat(deleteManifest2.existingFilesCount()).isEqualTo(2); + assertThat(deleteManifest2.hasAddedFiles()).isFalse(); + assertThat(deleteManifest2.hasDeletedFiles()).isFalse(); + + // the table must produce expected records after the rewrite + List expectedRecords = + Lists.newArrayList(new ThreeColumnRecord(5, "EEEEEEEEEE", "EEEE")); + assertThat(actualRecords()).isEqualTo(expectedRecords); + } + + @TestTemplate + public void testRewriteLargeDeleteManifestsPartitionedTable() throws IOException { + assumeThat(formatVersion).isEqualTo(2); + + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c3").build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + // generate enough delete files to have a reasonably sized manifest + List deleteFiles = Lists.newArrayList(); + for (int fileOrdinal = 0; fileOrdinal < 1000; fileOrdinal++) { + DeleteFile deleteFile = newDeleteFile(table, "c3=" + fileOrdinal); + deleteFiles.add(deleteFile); + } + + // commit delete files + RowDelta rowDelta = table.newRowDelta(); + for (DeleteFile deleteFile : deleteFiles) { + rowDelta.addDeletes(deleteFile); + } + rowDelta.commit(); + + // the current snapshot should contain only 1 delete manifest + List 
originalDeleteManifests = + table.currentSnapshot().deleteManifests(table.io()); + ManifestFile originalDeleteManifest = Iterables.getOnlyElement(originalDeleteManifests); + + // set the target manifest size to a small value to force splitting records into multiple files + table + .updateProperties() + .set( + TableProperties.MANIFEST_TARGET_SIZE_BYTES, + String.valueOf(originalDeleteManifest.length() / 2)) + .commit(); + + SparkActions actions = SparkActions.get(); + + String stagingLocation = java.nio.file.Files.createTempDirectory(temp, "junit").toString(); + + RewriteManifests.Result result = + actions + .rewriteManifests(table) + .rewriteIf(manifest -> true) + .option(RewriteManifestsSparkAction.USE_CACHING, useCaching) + .stagingLocation(stagingLocation) + .execute(); + + // the action must rewrite the original delete manifest and add at least 2 new ones + assertThat(result.rewrittenManifests()) + .hasSize(1) + .allMatch(m -> m.content() == ManifestContent.DELETES); + assertThat(result.addedManifests()) + .hasSizeGreaterThanOrEqualTo(2) + .allMatch(m -> m.content() == ManifestContent.DELETES); + assertManifestsLocation(result.addedManifests(), stagingLocation); + + // the current snapshot must return the correct number of delete manifests + List deleteManifests = table.currentSnapshot().deleteManifests(table.io()); + assertThat(deleteManifests).hasSizeGreaterThanOrEqualTo(2); + } + + @TestTemplate + public void testRewriteManifestsAfterUpgradeToV3() throws IOException { + assumeThat(formatVersion).isEqualTo(2); + + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c1").build(); + Map options = ImmutableMap.of(TableProperties.FORMAT_VERSION, "2"); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + DataFile dataFile1 = newDataFile(table, "c1=1"); + DeleteFile deleteFile1 = newDeletes(table, dataFile1); + table.newRowDelta().addRows(dataFile1).addDeletes(deleteFile1).commit(); + + DataFile dataFile2 = newDataFile(table, "c1=1"); + DeleteFile deleteFile2 = newDeletes(table, dataFile2); + table.newRowDelta().addRows(dataFile2).addDeletes(deleteFile2).commit(); + + // upgrade the table to enable DVs + table.updateProperties().set(TableProperties.FORMAT_VERSION, "3").commit(); + + DataFile dataFile3 = newDataFile(table, "c1=1"); + DeleteFile dv3 = newDV(table, dataFile3); + table.newRowDelta().addRows(dataFile3).addDeletes(dv3).commit(); + + SparkActions actions = SparkActions.get(); + + RewriteManifests.Result result = + actions + .rewriteManifests(table) + .rewriteIf(manifest -> true) + .option(RewriteManifestsSparkAction.USE_CACHING, useCaching) + .execute(); + + assertThat(result.rewrittenManifests()).as("Action should rewrite 6 manifests").hasSize(6); + assertThat(result.addedManifests()).as("Action should add 2 manifests").hasSize(2); + assertManifestsLocation(result.addedManifests()); + + table.refresh(); + + try (CloseableIterable tasks = table.newScan().planFiles()) { + for (FileScanTask fileTask : tasks) { + DataFile dataFile = fileTask.file(); + DeleteFile deleteFile = Iterables.getOnlyElement(fileTask.deletes()); + if (dataFile.location().equals(dataFile1.location())) { + assertThat(deleteFile.referencedDataFile()).isEqualTo(deleteFile1.referencedDataFile()); + assertEqual(deleteFile, deleteFile1); + } else if (dataFile.location().equals(dataFile2.location())) { + assertThat(deleteFile.referencedDataFile()).isEqualTo(deleteFile2.referencedDataFile()); + assertEqual(deleteFile, deleteFile2); + } else { + 
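+            // hypothetical clarifying comment: the remaining scan task must be dataFile3, whose delete was written as a DV after the upgrade to v3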
assertThat(deleteFile.referencedDataFile()).isEqualTo(dv3.referencedDataFile()); + assertEqual(deleteFile, dv3); + } + } + } + } + + private List actualRecords() { + return spark + .read() + .format("iceberg") + .load(tableLocation) + .as(Encoders.bean(ThreeColumnRecord.class)) + .sort("c1", "c2", "c3") + .collectAsList(); + } + + private void writeRecords(List records) { + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + writeDF(df); + } + + private void writeDF(Dataset df) { + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .option(SparkWriteOptions.DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_NONE) + .mode("append") + .save(tableLocation); + } + + private long computeManifestEntrySizeBytes(List manifests) { + long totalSize = 0L; + int numEntries = 0; + + for (ManifestFile manifest : manifests) { + totalSize += manifest.length(); + numEntries += + manifest.addedFilesCount() + manifest.existingFilesCount() + manifest.deletedFilesCount(); + } + + return totalSize / numEntries; + } + + private void assertManifestsLocation(Iterable manifests) { + assertManifestsLocation(manifests, null); + } + + private void assertManifestsLocation(Iterable manifests, String stagingLocation) { + if (shouldStageManifests && stagingLocation != null) { + assertThat(manifests).allMatch(manifest -> manifest.path().startsWith(stagingLocation)); + } else { + assertThat(manifests).allMatch(manifest -> manifest.path().startsWith(tableLocation)); + } + } + + private ManifestFile writeManifest(Table table, List files) throws IOException { + File manifestFile = File.createTempFile("generated-manifest", ".avro", temp.toFile()); + assertThat(manifestFile.delete()).isTrue(); + OutputFile outputFile = table.io().newOutputFile(manifestFile.getCanonicalPath()); + + ManifestWriter writer = + ManifestFiles.write(formatVersion, table.spec(), outputFile, null); + + try { + for (DataFile file : files) { + writer.add(file); + } + } finally { + writer.close(); + } + + return writer.toManifestFile(); + } + + private DataFile newDataFile(Table table, String partitionPath) { + return newDataFileBuilder(table).withPartitionPath(partitionPath).build(); + } + + private DataFile newDataFile(Table table, StructLike partition) { + return newDataFileBuilder(table).withPartition(partition).build(); + } + + private DataFiles.Builder newDataFileBuilder(Table table) { + return DataFiles.builder(table.spec()) + .withPath("/path/to/data-" + UUID.randomUUID() + ".parquet") + .withFileSizeInBytes(10) + .withRecordCount(1); + } + + private DeleteFile newDeletes(Table table, DataFile dataFile) { + return formatVersion >= 3 ? newDV(table, dataFile) : newDeleteFileWithRef(table, dataFile); + } + + private DeleteFile newDeleteFileWithRef(Table table, DataFile dataFile) { + return FileGenerationUtil.generatePositionDeleteFileWithRef(table, dataFile); + } + + private DeleteFile newDV(Table table, DataFile dataFile) { + return FileGenerationUtil.generateDV(table, dataFile); + } + + private DeleteFile newDeleteFile(Table table, String partitionPath) { + return formatVersion >= 3 + ? 
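+        // assumed intent, inferred from the builder calls below: for v3 the helper mimics a deletion vector (Puffin path plus content offset/size metadata)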
FileMetadata.deleteFileBuilder(table.spec()) + .ofPositionDeletes() + .withPath("/path/to/pos-deletes-" + UUID.randomUUID() + ".puffin") + .withFileSizeInBytes(5) + .withPartitionPath(partitionPath) + .withRecordCount(1) + .withContentOffset(ThreadLocalRandom.current().nextInt()) + .withContentSizeInBytes(ThreadLocalRandom.current().nextInt()) + .build() + : FileMetadata.deleteFileBuilder(table.spec()) + .ofPositionDeletes() + .withPath("/path/to/pos-deletes-" + UUID.randomUUID() + ".parquet") + .withFileSizeInBytes(5) + .withPartitionPath(partitionPath) + .withRecordCount(1) + .build(); + } + + private List> generatePosDeletes(String predicate) { + List rows = + spark + .read() + .format("iceberg") + .load(tableLocation) + .selectExpr("_file", "_pos") + .where(predicate) + .collectAsList(); + + List> deletes = Lists.newArrayList(); + + for (Row row : rows) { + deletes.add(Pair.of(row.getString(0), row.getLong(1))); + } + + return deletes; + } + + private Pair writePosDeletes( + Table table, List> deletes) throws IOException { + return writePosDeletes(table, null, deletes); + } + + private Pair writePosDeletes( + Table table, StructLike partition, List> deletes) + throws IOException { + OutputFile outputFile = Files.localOutput(File.createTempFile("junit", null, temp.toFile())); + return FileHelpers.writeDeleteFile(table, outputFile, partition, deletes, formatVersion); + } + + private DeleteFile writeEqDeletes(Table table, String key, Object... values) throws IOException { + return writeEqDeletes(table, null, key, values); + } + + private DeleteFile writeEqDeletes(Table table, StructLike partition, String key, Object... values) + throws IOException { + List deletes = Lists.newArrayList(); + Schema deleteSchema = table.schema().select(key); + Record delete = GenericRecord.create(deleteSchema); + + for (Object value : values) { + deletes.add(delete.copy(key, value)); + } + + OutputFile outputFile = Files.localOutput(File.createTempFile("junit", null, temp.toFile())); + return FileHelpers.writeDeleteFile(table, outputFile, partition, deletes, deleteSchema); + } + + private void assertDeletes(DataFile dataFile, DeleteFile deleteFile) { + assertThat(deleteFile.referencedDataFile()).isEqualTo(dataFile.location()); + if (formatVersion >= 3) { + assertThat(deleteFile.contentOffset()).isNotNull(); + assertThat(deleteFile.contentSizeInBytes()).isNotNull(); + } else { + assertThat(deleteFile.contentOffset()).isNull(); + assertThat(deleteFile.contentSizeInBytes()).isNull(); + } + } + + private void assertEqual(DeleteFile deleteFile1, DeleteFile deleteFile2) { + assertThat(deleteFile1.location()).isEqualTo(deleteFile2.location()); + assertThat(deleteFile1.content()).isEqualTo(deleteFile2.content()); + assertThat(deleteFile1.specId()).isEqualTo(deleteFile2.specId()); + assertThat(deleteFile1.partition()).isEqualTo(deleteFile2.partition()); + assertThat(deleteFile1.format()).isEqualTo(deleteFile2.format()); + assertThat(deleteFile1.referencedDataFile()).isEqualTo(deleteFile2.referencedDataFile()); + assertThat(deleteFile1.contentOffset()).isEqualTo(deleteFile2.contentOffset()); + assertThat(deleteFile1.contentSizeInBytes()).isEqualTo(deleteFile2.contentSizeInBytes()); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewritePositionDeleteFilesAction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewritePositionDeleteFilesAction.java new file mode 100644 index 000000000000..8547f9753f5e --- /dev/null +++ 
b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewritePositionDeleteFilesAction.java @@ -0,0 +1,1126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.spark.sql.functions.expr; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.MetadataTableUtils; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionData; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.PositionDeletesScanTask; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.RewritePositionDeleteFiles.FileGroupRewriteResult; +import org.apache.iceberg.actions.RewritePositionDeleteFiles.Result; +import org.apache.iceberg.actions.SizeBasedFileRewriter; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.deletes.DeleteGranularity; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.source.FourColumnRecord; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Pair; +import 
org.apache.iceberg.util.StructLikeMap; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.io.TempDir; + +public class TestRewritePositionDeleteFilesAction extends CatalogTestBase { + + private static final String TABLE_NAME = "test_table"; + private static final Schema SCHEMA = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + + private static final Map CATALOG_PROPS = + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "cache-enabled", "false"); + + private static final int SCALE = 4000; + private static final int DELETES_SCALE = 1000; + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}, fileFormat = {3}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + CATALOG_PROPS, + FileFormat.PARQUET + } + }; + } + + @TempDir private Path temp; + + @Parameter(index = 3) + private FileFormat format; + + @AfterEach + public void cleanup() { + validationCatalog.dropTable(TableIdentifier.of("default", TABLE_NAME)); + } + + @TestTemplate + public void testEmptyTable() { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c1").build(); + Table table = + validationCatalog.createTable( + TableIdentifier.of("default", TABLE_NAME), SCHEMA, spec, tableProperties()); + + Result result = SparkActions.get(spark).rewritePositionDeletes(table).execute(); + assertThat(result.rewrittenDeleteFilesCount()).as("No rewritten delete files").isZero(); + assertThat(result.addedDeleteFilesCount()).as("No added delete files").isZero(); + } + + @TestTemplate + public void testFileGranularity() throws Exception { + checkDeleteGranularity(DeleteGranularity.FILE); + } + + @TestTemplate + public void testPartitionGranularity() throws Exception { + checkDeleteGranularity(DeleteGranularity.PARTITION); + } + + private void checkDeleteGranularity(DeleteGranularity deleteGranularity) throws Exception { + Table table = createTableUnpartitioned(2, SCALE); + + table + .updateProperties() + .set(TableProperties.DELETE_GRANULARITY, deleteGranularity.toString()) + .commit(); + + List dataFiles = TestHelpers.dataFiles(table); + assertThat(dataFiles).hasSize(2); + + writePosDeletesForFiles(table, 2, DELETES_SCALE, dataFiles); + + List deleteFiles = deleteFiles(table); + assertThat(deleteFiles).hasSize(2); + + Result result = + SparkActions.get(spark) + .rewritePositionDeletes(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + + int expectedDeleteFilesCount = deleteGranularity == DeleteGranularity.FILE ? 
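+        // with 2 data files, FILE granularity is expected to keep one delete file per data file, while PARTITION granularity collapses them into a single file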
2 : 1; + assertThat(result.addedDeleteFilesCount()).isEqualTo(expectedDeleteFilesCount); + } + + @TestTemplate + public void testUnpartitioned() throws Exception { + Table table = createTableUnpartitioned(2, SCALE); + List dataFiles = TestHelpers.dataFiles(table); + writePosDeletesForFiles(table, 2, DELETES_SCALE, dataFiles); + assertThat(dataFiles).hasSize(2); + + List deleteFiles = deleteFiles(table); + assertThat(deleteFiles).hasSize(2); + + List expectedRecords = records(table); + List expectedDeletes = deleteRecords(table); + assertThat(expectedRecords).hasSize(2000); + assertThat(expectedDeletes).hasSize(2000); + + Result result = + SparkActions.get(spark) + .rewritePositionDeletes(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + List newDeleteFiles = deleteFiles(table); + assertThat(newDeleteFiles).as("Expected 1 new delete file").hasSize(1); + assertLocallySorted(newDeleteFiles); + assertNotContains(deleteFiles, newDeleteFiles); + checkResult(result, deleteFiles, newDeleteFiles, 1); + checkSequenceNumbers(table, deleteFiles, newDeleteFiles); + + List actualRecords = records(table); + List actualDeletes = deleteRecords(table); + assertEquals("Rows must match", expectedRecords, actualRecords); + assertEquals("Position deletes must match", expectedDeletes, actualDeletes); + } + + @TestTemplate + public void testRewriteAll() throws Exception { + Table table = createTablePartitioned(4, 2, SCALE); + + List dataFiles = TestHelpers.dataFiles(table); + writePosDeletesForFiles(table, 2, DELETES_SCALE, dataFiles); + assertThat(dataFiles).hasSize(4); + + List deleteFiles = deleteFiles(table); + assertThat(deleteFiles).hasSize(8); + + List expectedRecords = records(table); + List expectedDeletes = deleteRecords(table); + assertThat(expectedRecords).hasSize(12000); + assertThat(expectedDeletes).hasSize(4000); + + Result result = + SparkActions.get(spark) + .rewritePositionDeletes(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .option(SizeBasedFileRewriter.TARGET_FILE_SIZE_BYTES, Long.toString(Long.MAX_VALUE - 1)) + .execute(); + + List newDeleteFiles = deleteFiles(table); + assertThat(newDeleteFiles).hasSize(4); + assertNotContains(deleteFiles, newDeleteFiles); + assertLocallySorted(newDeleteFiles); + checkResult(result, deleteFiles, newDeleteFiles, 4); + checkSequenceNumbers(table, deleteFiles, newDeleteFiles); + + List actualRecords = records(table); + List actualDeletes = deleteRecords(table); + assertEquals("Rows must match", expectedRecords, actualRecords); + assertEquals("Position deletes must match", expectedDeletes, actualDeletes); + } + + @TestTemplate + public void testRewriteFilter() throws Exception { + Table table = createTablePartitioned(4, 2, SCALE); + table.refresh(); + + List dataFiles = TestHelpers.dataFiles(table); + writePosDeletesForFiles(table, 2, DELETES_SCALE, dataFiles); + assertThat(dataFiles).hasSize(4); + + List deleteFiles = deleteFiles(table); + assertThat(deleteFiles).hasSize(8); + + table.refresh(); + List expectedRecords = records(table); + List expectedDeletes = deleteRecords(table); + assertThat(expectedRecords).hasSize(12000); + assertThat(expectedDeletes).hasSize(4000); + + Expression filter = + Expressions.and( + Expressions.greaterThan("c3", "0"), // should have no effect + Expressions.or(Expressions.equal("c1", 1), Expressions.equal("c1", 2))); + + Result result = + SparkActions.get(spark) + .rewritePositionDeletes(table) + .filter(filter) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + 
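+            // note (inferred): an effectively unlimited target size packs each rewritten partition's deletes into one output file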
.option(SizeBasedFileRewriter.TARGET_FILE_SIZE_BYTES, Long.toString(Long.MAX_VALUE - 1)) + .execute(); + + List newDeleteFiles = except(deleteFiles(table), deleteFiles); + assertThat(newDeleteFiles).as("Should have 4 delete files").hasSize(2); + + List expectedRewrittenFiles = + filterFiles(table, deleteFiles, ImmutableList.of(1), ImmutableList.of(2)); + assertLocallySorted(newDeleteFiles); + checkResult(result, expectedRewrittenFiles, newDeleteFiles, 2); + + List actualRecords = records(table); + List actualDeletes = deleteRecords(table); + assertEquals("Rows must match", expectedRecords, actualRecords); + assertEquals("Position deletes must match", expectedDeletes, actualDeletes); + } + + @TestTemplate + public void testRewriteToSmallerTarget() throws Exception { + Table table = createTablePartitioned(4, 2, SCALE); + + List dataFiles = TestHelpers.dataFiles(table); + writePosDeletesForFiles(table, 2, DELETES_SCALE, dataFiles); + assertThat(dataFiles).hasSize(4); + + List expectedRecords = records(table); + List expectedDeletes = deleteRecords(table); + assertThat(expectedRecords).hasSize(12000); + assertThat(expectedDeletes).hasSize(4000); + + List deleteFiles = deleteFiles(table); + assertThat(deleteFiles).hasSize(8); + + long avgSize = size(deleteFiles) / deleteFiles.size(); + + Result result = + SparkActions.get(spark) + .rewritePositionDeletes(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .option(SizeBasedFileRewriter.TARGET_FILE_SIZE_BYTES, String.valueOf(avgSize / 2)) + .execute(); + List newDeleteFiles = deleteFiles(table); + assertThat(newDeleteFiles).as("Should have 8 new delete files").hasSize(8); + assertNotContains(deleteFiles, newDeleteFiles); + assertLocallySorted(newDeleteFiles); + checkResult(result, deleteFiles, newDeleteFiles, 4); + checkSequenceNumbers(table, deleteFiles, newDeleteFiles); + + List actualRecords = records(table); + List actualDeletes = deleteRecords(table); + assertEquals("Rows must match", expectedRecords, actualRecords); + assertEquals("Position deletes must match", expectedDeletes, actualDeletes); + } + + @TestTemplate + public void testRemoveDanglingDeletes() throws Exception { + Table table = createTablePartitioned(4, 2, SCALE); + + List dataFiles = TestHelpers.dataFiles(table); + writePosDeletesForFiles( + table, + 2, + DELETES_SCALE, + dataFiles, + true /* Disable commit-time ManifestFilterManager removal of dangling deletes */); + + assertThat(dataFiles).hasSize(4); + + List deleteFiles = deleteFiles(table); + assertThat(deleteFiles).hasSize(8); + + List expectedRecords = records(table); + List expectedDeletes = deleteRecords(table); + assertThat(expectedRecords).hasSize(12000); + assertThat(expectedDeletes).hasSize(4000); + + SparkActions.get(spark) + .rewriteDataFiles(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + + Result result = + SparkActions.get(spark) + .rewritePositionDeletes(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + List newDeleteFiles = deleteFiles(table); + assertThat(newDeleteFiles).as("Should have 0 new delete files").hasSize(0); + assertNotContains(deleteFiles, newDeleteFiles); + assertLocallySorted(newDeleteFiles); + checkResult(result, deleteFiles, newDeleteFiles, 4); + checkSequenceNumbers(table, deleteFiles, newDeleteFiles); + + List actualRecords = records(table); + List actualDeletes = deleteRecords(table); + assertEquals("Rows must match", expectedRecords, actualRecords); + assertThat(actualDeletes).as("Should be no new position 
deletes").hasSize(0); + } + + @TestTemplate + public void testSomePartitionsDanglingDeletes() throws Exception { + Table table = createTablePartitioned(4, 2, SCALE); + + List dataFiles = TestHelpers.dataFiles(table); + writePosDeletesForFiles(table, 2, DELETES_SCALE, dataFiles); + assertThat(dataFiles).hasSize(4); + + List deleteFiles = deleteFiles(table); + assertThat(deleteFiles).hasSize(8); + + List expectedRecords = records(table); + List expectedDeletes = deleteRecords(table); + assertThat(expectedRecords).hasSize(12000); + assertThat(expectedDeletes).hasSize(4000); + + // Rewrite half the data files + Expression filter = Expressions.or(Expressions.equal("c1", 0), Expressions.equal("c1", 1)); + SparkActions.get(spark) + .rewriteDataFiles(table) + .filter(filter) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + + Result result = + SparkActions.get(spark) + .rewritePositionDeletes(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + List newDeleteFiles = deleteFiles(table); + assertThat(newDeleteFiles).as("Should have 2 new delete files").hasSize(2); + assertNotContains(deleteFiles, newDeleteFiles); + assertLocallySorted(newDeleteFiles); + checkResult(result, deleteFiles, newDeleteFiles, 4); + checkSequenceNumbers(table, deleteFiles, newDeleteFiles); + + // As only half the files have been rewritten, + // we expect to retain position deletes only for those not rewritten + expectedDeletes = + expectedDeletes.stream() + .filter( + r -> { + Object[] partition = (Object[]) r[3]; + return partition[0] == (Integer) 2 || partition[0] == (Integer) 3; + }) + .collect(Collectors.toList()); + + List actualRecords = records(table); + List actualDeletes = deleteRecords(table); + assertEquals("Rows must match", expectedRecords, actualRecords); + assertEquals("Position deletes must match", expectedDeletes, actualDeletes); + } + + @TestTemplate + public void testRewriteFilterRemoveDangling() throws Exception { + Table table = createTablePartitioned(4, 2, SCALE); + table.refresh(); + + List dataFiles = TestHelpers.dataFiles(table); + writePosDeletesForFiles(table, 2, DELETES_SCALE, dataFiles, true); + assertThat(dataFiles).hasSize(4); + + List deleteFiles = deleteFiles(table); + assertThat(deleteFiles).hasSize(8); + + table.refresh(); + List expectedRecords = records(table); + List expectedDeletes = deleteRecords(table); + assertThat(expectedRecords).hasSize(12000); // 16000 data - 4000 delete rows + assertThat(expectedDeletes).hasSize(4000); + + SparkActions.get(spark) + .rewriteDataFiles(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + + Expression filter = Expressions.or(Expressions.equal("c1", 0), Expressions.equal("c1", 1)); + Result result = + SparkActions.get(spark) + .rewritePositionDeletes(table) + .filter(filter) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .option(SizeBasedFileRewriter.TARGET_FILE_SIZE_BYTES, Long.toString(Long.MAX_VALUE - 1)) + .execute(); + + List newDeleteFiles = except(deleteFiles(table), deleteFiles); + assertThat(newDeleteFiles).as("Should have 2 new delete files").hasSize(0); + + List expectedRewrittenFiles = + filterFiles(table, deleteFiles, ImmutableList.of(0), ImmutableList.of(1)); + checkResult(result, expectedRewrittenFiles, newDeleteFiles, 2); + + List actualRecords = records(table); + List allDeletes = deleteRecords(table); + // Only non-compacted deletes remain + List expectedDeletesFiltered = + filterDeletes(expectedDeletes, ImmutableList.of(2), ImmutableList.of(3)); + 
assertEquals("Rows must match", expectedRecords, actualRecords); + assertEquals("Position deletes must match", expectedDeletesFiltered, allDeletes); + } + + @TestTemplate + public void testPartitionEvolutionAdd() throws Exception { + Table table = createTableUnpartitioned(2, SCALE); + List unpartitionedDataFiles = TestHelpers.dataFiles(table); + writePosDeletesForFiles(table, 2, DELETES_SCALE, unpartitionedDataFiles); + assertThat(unpartitionedDataFiles).hasSize(2); + + List unpartitionedDeleteFiles = deleteFiles(table); + assertThat(unpartitionedDeleteFiles).hasSize(2); + + List expectedUnpartitionedDeletes = deleteRecords(table); + List expectedUnpartitionedRecords = records(table); + assertThat(expectedUnpartitionedRecords).hasSize(2000); + assertThat(expectedUnpartitionedDeletes).hasSize(2000); + + table.updateSpec().addField("c1").commit(); + writeRecords(table, 2, SCALE, 2); + List partitionedDataFiles = + except(TestHelpers.dataFiles(table), unpartitionedDataFiles); + writePosDeletesForFiles(table, 2, DELETES_SCALE, partitionedDataFiles); + assertThat(partitionedDataFiles).hasSize(2); + + List partitionedDeleteFiles = except(deleteFiles(table), unpartitionedDeleteFiles); + assertThat(partitionedDeleteFiles).hasSize(4); + + List expectedDeletes = deleteRecords(table); + List expectedRecords = records(table); + assertThat(expectedDeletes).hasSize(4000); + assertThat(expectedRecords).hasSize(8000); + + Result result = + SparkActions.get(spark) + .rewritePositionDeletes(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + + List rewrittenDeleteFiles = + Stream.concat(unpartitionedDeleteFiles.stream(), partitionedDeleteFiles.stream()) + .collect(Collectors.toList()); + List newDeleteFiles = deleteFiles(table); + assertThat(newDeleteFiles).as("Should have 3 new delete files").hasSize(3); + assertNotContains(rewrittenDeleteFiles, newDeleteFiles); + assertLocallySorted(newDeleteFiles); + checkResult(result, rewrittenDeleteFiles, newDeleteFiles, 3); + checkSequenceNumbers(table, rewrittenDeleteFiles, newDeleteFiles); + + List actualRecords = records(table); + List actualDeletes = deleteRecords(table); + assertEquals("Rows must match", expectedRecords, actualRecords); + assertEquals("Position deletes must match", expectedDeletes, actualDeletes); + } + + @TestTemplate + public void testPartitionEvolutionRemove() throws Exception { + Table table = createTablePartitioned(2, 2, SCALE); + List dataFilesUnpartitioned = TestHelpers.dataFiles(table); + writePosDeletesForFiles(table, 2, DELETES_SCALE, dataFilesUnpartitioned); + assertThat(dataFilesUnpartitioned).hasSize(2); + + List deleteFilesUnpartitioned = deleteFiles(table); + assertThat(deleteFilesUnpartitioned).hasSize(4); + + table.updateSpec().removeField("c1").commit(); + + writeRecords(table, 2, SCALE); + List dataFilesPartitioned = + except(TestHelpers.dataFiles(table), dataFilesUnpartitioned); + writePosDeletesForFiles(table, 2, DELETES_SCALE, dataFilesPartitioned); + assertThat(dataFilesPartitioned).hasSize(2); + + List deleteFilesPartitioned = except(deleteFiles(table), deleteFilesUnpartitioned); + assertThat(deleteFilesPartitioned).hasSize(2); + + List expectedRecords = records(table); + List expectedDeletes = deleteRecords(table); + assertThat(expectedDeletes).hasSize(4000); + assertThat(expectedRecords).hasSize(8000); + + List expectedRewritten = deleteFiles(table); + assertThat(expectedRewritten).hasSize(6); + + Result result = + SparkActions.get(spark) + .rewritePositionDeletes(table) + 
.option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + List newDeleteFiles = deleteFiles(table); + assertThat(newDeleteFiles).as("Should have 3 new delete files").hasSize(3); + assertNotContains(expectedRewritten, newDeleteFiles); + assertLocallySorted(newDeleteFiles); + checkResult(result, expectedRewritten, newDeleteFiles, 3); + checkSequenceNumbers(table, expectedRewritten, newDeleteFiles); + + List actualRecords = records(table); + List actualDeletes = deleteRecords(table); + assertEquals("Rows must match", expectedRecords, actualRecords); + assertEquals("Position deletes must match", expectedDeletes, actualDeletes); + } + + @TestTemplate + public void testSchemaEvolution() throws Exception { + Table table = createTablePartitioned(2, 2, SCALE); + List dataFiles = TestHelpers.dataFiles(table); + writePosDeletesForFiles(table, 2, DELETES_SCALE, dataFiles); + assertThat(dataFiles).hasSize(2); + + List deleteFiles = deleteFiles(table); + assertThat(deleteFiles).hasSize(4); + + table.updateSchema().addColumn("c4", Types.StringType.get()).commit(); + writeNewSchemaRecords(table, 2, SCALE, 2, 2); + + int newColId = table.schema().findField("c4").fieldId(); + List newSchemaDataFiles = + TestHelpers.dataFiles(table).stream() + .filter(f -> f.upperBounds().containsKey(newColId)) + .collect(Collectors.toList()); + writePosDeletesForFiles(table, 2, DELETES_SCALE, newSchemaDataFiles); + + List newSchemaDeleteFiles = except(deleteFiles(table), deleteFiles); + assertThat(newSchemaDeleteFiles).hasSize(4); + + table.refresh(); + List expectedDeletes = deleteRecords(table); + List expectedRecords = records(table); + assertThat(expectedDeletes).hasSize(4000); // 4 files * 1000 per file + assertThat(expectedRecords).hasSize(12000); // 4 * 4000 - 4000 + + Result result = + SparkActions.get(spark) + .rewritePositionDeletes(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + + List rewrittenDeleteFiles = + Stream.concat(deleteFiles.stream(), newSchemaDeleteFiles.stream()) + .collect(Collectors.toList()); + List newDeleteFiles = deleteFiles(table); + assertThat(newDeleteFiles).as("Should have 2 new delete files").hasSize(4); + assertNotContains(rewrittenDeleteFiles, newDeleteFiles); + assertLocallySorted(newDeleteFiles); + checkResult(result, rewrittenDeleteFiles, newDeleteFiles, 4); + checkSequenceNumbers(table, rewrittenDeleteFiles, newDeleteFiles); + + List actualRecords = records(table); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @TestTemplate + public void testSnapshotProperty() throws Exception { + Table table = createTableUnpartitioned(2, SCALE); + List dataFiles = TestHelpers.dataFiles(table); + writePosDeletesForFiles(table, 2, DELETES_SCALE, dataFiles); + assertThat(dataFiles).hasSize(2); + + List deleteFiles = deleteFiles(table); + assertThat(deleteFiles).hasSize(2); + + Result ignored = + SparkActions.get(spark) + .rewritePositionDeletes(table) + .snapshotProperty("key", "value") + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + assertThat(table.currentSnapshot().summary()) + .containsAllEntriesOf(ImmutableMap.of("key", "value")); + + // make sure internal produced properties are not lost + String[] commitMetricsKeys = + new String[] { + SnapshotSummary.ADDED_DELETE_FILES_PROP, + SnapshotSummary.ADDED_POS_DELETES_PROP, + SnapshotSummary.CHANGED_PARTITION_COUNT_PROP, + SnapshotSummary.REMOVED_DELETE_FILES_PROP, + SnapshotSummary.REMOVED_POS_DELETES_PROP, + SnapshotSummary.TOTAL_DATA_FILES_PROP, + 
SnapshotSummary.TOTAL_DELETE_FILES_PROP, + }; + assertThat(table.currentSnapshot().summary()).containsKeys(commitMetricsKeys); + } + + @TestTemplate + public void testRewriteManyColumns() throws Exception { + List fields = + Lists.newArrayList(Types.NestedField.required(0, "id", Types.LongType.get())); + List additionalCols = + IntStream.range(1, 1010) + .mapToObj(i -> Types.NestedField.optional(i, "c" + i, Types.StringType.get())) + .collect(Collectors.toList()); + fields.addAll(additionalCols); + Schema schema = new Schema(fields); + PartitionSpec spec = PartitionSpec.builderFor(schema).bucket("id", 2).build(); + Table table = + validationCatalog.createTable( + TableIdentifier.of("default", TABLE_NAME), schema, spec, tableProperties()); + + Dataset df = + spark + .range(4) + .withColumns( + IntStream.range(1, 1010) + .boxed() + .collect(Collectors.toMap(i -> "c" + i, i -> expr("CAST(id as STRING)")))); + StructType sparkSchema = spark.table(name(table)).schema(); + spark + .createDataFrame(df.rdd(), sparkSchema) + .coalesce(1) + .write() + .format("iceberg") + .mode("append") + .save(name(table)); + + List dataFiles = TestHelpers.dataFiles(table); + writePosDeletesForFiles(table, 1, 1, dataFiles); + assertThat(dataFiles).hasSize(2); + + List deleteFiles = deleteFiles(table); + assertThat(deleteFiles).hasSize(2); + + List expectedRecords = records(table); + List expectedDeletes = deleteRecords(table); + assertThat(expectedRecords).hasSize(2); + assertThat(expectedDeletes).hasSize(2); + + Result result = + SparkActions.get(spark) + .rewritePositionDeletes(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .option(SizeBasedFileRewriter.TARGET_FILE_SIZE_BYTES, Long.toString(Long.MAX_VALUE - 1)) + .execute(); + + List newDeleteFiles = deleteFiles(table); + assertThat(newDeleteFiles).hasSize(2); + assertNotContains(deleteFiles, newDeleteFiles); + assertLocallySorted(newDeleteFiles); + checkResult(result, deleteFiles, newDeleteFiles, 2); + checkSequenceNumbers(table, deleteFiles, newDeleteFiles); + + List actualRecords = records(table); + List actualDeletes = deleteRecords(table); + assertEquals("Rows must match", expectedRecords, actualRecords); + assertEquals("Position deletes must match", expectedDeletes, actualDeletes); + } + + private Table createTablePartitioned(int partitions, int files, int numRecords) { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c1").build(); + Table table = + validationCatalog.createTable( + TableIdentifier.of("default", TABLE_NAME), SCHEMA, spec, tableProperties()); + + writeRecords(table, files, numRecords, partitions); + return table; + } + + private Table createTableUnpartitioned(int files, int numRecords) { + Table table = + validationCatalog.createTable( + TableIdentifier.of("default", TABLE_NAME), + SCHEMA, + PartitionSpec.unpartitioned(), + tableProperties()); + + writeRecords(table, files, numRecords); + return table; + } + + private Map tableProperties() { + return ImmutableMap.of( + TableProperties.DEFAULT_WRITE_METRICS_MODE, + "full", + TableProperties.FORMAT_VERSION, + "2", + TableProperties.DEFAULT_FILE_FORMAT, + format.toString()); + } + + private void writeRecords(Table table, int files, int numRecords) { + writeRecords(table, files, numRecords, 1); + } + + private void writeRecords(Table table, int files, int numRecords, int numPartitions) { + writeRecordsWithPartitions( + table, + files, + numRecords, + IntStream.range(0, numPartitions).mapToObj(ImmutableList::of).collect(Collectors.toList())); + } + + private void 
writeRecordsWithPartitions( + Table table, int files, int numRecords, List> partitions) { + int partitionTypeSize = table.spec().partitionType().fields().size(); + assertThat(partitionTypeSize) + .as("This method currently supports only two columns as partition columns") + .isLessThanOrEqualTo(2); + + BiFunction, ThreeColumnRecord> recordFunction = + (i, partValues) -> { + switch (partitionTypeSize) { + case (0): + return new ThreeColumnRecord(i, String.valueOf(i), String.valueOf(i)); + case (1): + return new ThreeColumnRecord(partValues.get(0), String.valueOf(i), String.valueOf(i)); + case (2): + return new ThreeColumnRecord( + partValues.get(0), String.valueOf(partValues.get(1)), String.valueOf(i)); + default: + throw new ValidationException( + "This method currently supports only two columns as partition columns"); + } + }; + List records = + partitions.stream() + .flatMap( + partition -> + IntStream.range(0, numRecords) + .mapToObj(i -> recordFunction.apply(i, partition))) + .collect(Collectors.toList()); + spark + .createDataFrame(records, ThreeColumnRecord.class) + .repartition(files) + .write() + .format("iceberg") + .mode("append") + .save(name(table)); + table.refresh(); + } + + private void writeNewSchemaRecords( + Table table, int files, int numRecords, int startingPartition, int partitions) { + List records = + IntStream.range(startingPartition, startingPartition + partitions) + .boxed() + .flatMap( + partition -> + IntStream.range(0, numRecords) + .mapToObj( + i -> + new FourColumnRecord( + partition, + String.valueOf(i), + String.valueOf(i), + String.valueOf(i)))) + .collect(Collectors.toList()); + spark + .createDataFrame(records, FourColumnRecord.class) + .repartition(files) + .write() + .format("iceberg") + .mode("append") + .save(name(table)); + } + + private List records(Table table) { + return rowsToJava( + spark.read().format("iceberg").load(name(table)).sort("c1", "c2", "c3").collectAsList()); + } + + private List deleteRecords(Table table) { + String[] additionalFields; + // do not select delete_file_path for comparison + // as delete files have been rewritten + if (table.spec().isUnpartitioned()) { + additionalFields = new String[] {"pos", "row"}; + } else { + additionalFields = new String[] {"pos", "row", "partition", "spec_id"}; + } + return rowsToJava( + spark + .read() + .format("iceberg") + .load(name(table) + ".position_deletes") + .select("file_path", additionalFields) + .sort("file_path", "pos") + .collectAsList()); + } + + private void writePosDeletesForFiles( + Table table, int deleteFilesPerPartition, int deletesPerDataFile, List files) + throws IOException { + writePosDeletesForFiles(table, deleteFilesPerPartition, deletesPerDataFile, files, false); + } + + private void writePosDeletesForFiles( + Table table, + int deleteFilesPerPartition, + int deletesPerDataFile, + List files, + boolean transactional) + throws IOException { + + Map> filesByPartition = + files.stream().collect(Collectors.groupingBy(ContentFile::partition)); + List deleteFiles = + Lists.newArrayListWithCapacity(deleteFilesPerPartition * filesByPartition.size()); + String suffix = String.format(".%s", FileFormat.PARQUET.name().toLowerCase()); + + for (Map.Entry> filesByPartitionEntry : + filesByPartition.entrySet()) { + + StructLike partition = filesByPartitionEntry.getKey(); + List partitionFiles = filesByPartitionEntry.getValue(); + + int deletesForPartition = partitionFiles.size() * deletesPerDataFile; + assertThat(deletesForPartition % deleteFilesPerPartition) + .as( + "Number of 
delete files per partition should be " + + "evenly divisible by requested deletes per data file times number of data files in this partition") + .isZero(); + + int deleteFileSize = deletesForPartition / deleteFilesPerPartition; + int counter = 0; + List> deletes = Lists.newArrayList(); + for (DataFile partitionFile : partitionFiles) { + for (int deletePos = 0; deletePos < deletesPerDataFile; deletePos++) { + deletes.add(Pair.of(partitionFile.path(), (long) deletePos)); + counter++; + if (counter == deleteFileSize) { + // Dump to file and reset variables + OutputFile output = + Files.localOutput(File.createTempFile("junit", suffix, temp.toFile())); + deleteFiles.add(FileHelpers.writeDeleteFile(table, output, partition, deletes).first()); + counter = 0; + deletes.clear(); + } + } + } + } + + if (transactional) { + RowDelta rowDelta = table.newRowDelta(); + deleteFiles.forEach(rowDelta::addDeletes); + rowDelta.commit(); + } else { + deleteFiles.forEach( + deleteFile -> { + RowDelta rowDelta = table.newRowDelta(); + rowDelta.addDeletes(deleteFile); + rowDelta.commit(); + }); + } + } + + private List deleteFiles(Table table) { + Table deletesTable = + MetadataTableUtils.createMetadataTableInstance(table, MetadataTableType.POSITION_DELETES); + CloseableIterable tasks = deletesTable.newBatchScan().planFiles(); + return Lists.newArrayList( + CloseableIterable.transform(tasks, t -> ((PositionDeletesScanTask) t).file())); + } + + private > List except(List first, List second) { + Set secondPaths = + second.stream().map(f -> f.path().toString()).collect(Collectors.toSet()); + return first.stream() + .filter(f -> !secondPaths.contains(f.path().toString())) + .collect(Collectors.toList()); + } + + private void assertNotContains(List original, List rewritten) { + Set originalPaths = + original.stream().map(f -> f.path().toString()).collect(Collectors.toSet()); + Set rewrittenPaths = + rewritten.stream().map(f -> f.path().toString()).collect(Collectors.toSet()); + rewrittenPaths.retainAll(originalPaths); + assertThat(rewrittenPaths).hasSize(0); + } + + private void assertLocallySorted(List deleteFiles) { + for (DeleteFile deleteFile : deleteFiles) { + Dataset deletes = + spark.read().format("iceberg").load("default." + TABLE_NAME + ".position_deletes"); + deletes.filter(deletes.col("delete_file_path").equalTo(deleteFile.path().toString())); + List rows = deletes.collectAsList(); + assertThat(rows).as("Empty delete file found").isNotEmpty(); + int lastPos = 0; + String lastPath = ""; + for (Row row : rows) { + String path = row.getAs("file_path"); + long pos = row.getAs("pos"); + if (path.compareTo(lastPath) < 0) { + fail(String.format("File_path not sorted, Found %s after %s", path, lastPath)); + } else if (path.equals(lastPath)) { + assertThat(pos).as("Pos not sorted").isGreaterThanOrEqualTo(lastPos); + } + } + } + } + + private String name(Table table) { + String[] splits = table.name().split("\\."); + + assertThat(splits).hasSize(3); + return String.format("%s.%s", splits[1], splits[2]); + } + + private long size(List deleteFiles) { + return deleteFiles.stream().mapToLong(DeleteFile::fileSizeInBytes).sum(); + } + + private List filterDeletes(List deletes, List... 
partitionValues) { + Stream matches = + deletes.stream() + .filter( + r -> { + Object[] partition = (Object[]) r[3]; + return Arrays.stream(partitionValues) + .map(partitionValue -> match(partition, partitionValue)) + .reduce((a, b) -> a || b) + .get(); + }); + return sorted(matches).collect(Collectors.toList()); + } + + private boolean match(Object[] partition, List expectedPartition) { + return IntStream.range(0, expectedPartition.size()) + .mapToObj(j -> partition[j] == expectedPartition.get(j)) + .reduce((a, b) -> a && b) + .get(); + } + + private Stream sorted(Stream deletes) { + return deletes.sorted( + (a, b) -> { + String aFilePath = (String) a[0]; + String bFilePath = (String) b[0]; + int filePathCompare = aFilePath.compareTo(bFilePath); + if (filePathCompare != 0) { + return filePathCompare; + } else { + long aPos = (long) a[1]; + long bPos = (long) b[1]; + return Long.compare(aPos, bPos); + } + }); + } + + private List filterFiles( + Table table, List files, List... partitionValues) { + List partitionTypes = + table.specs().values().stream() + .map(PartitionSpec::partitionType) + .collect(Collectors.toList()); + List partitionDatas = + Arrays.stream(partitionValues) + .map( + partitionValue -> { + Types.StructType thisType = + partitionTypes.stream() + .filter(f -> f.fields().size() == partitionValue.size()) + .findFirst() + .get(); + PartitionData partition = new PartitionData(thisType); + for (int i = 0; i < partitionValue.size(); i++) { + partition.set(i, partitionValue.get(i)); + } + return partition; + }) + .collect(Collectors.toList()); + + return files.stream() + .filter(f -> partitionDatas.stream().anyMatch(data -> f.partition().equals(data))) + .collect(Collectors.toList()); + } + + private void checkResult( + Result result, + List rewrittenDeletes, + List newDeletes, + int expectedGroups) { + assertThat(rewrittenDeletes.size()) + .as("Expected rewritten delete file count does not match") + .isEqualTo(result.rewrittenDeleteFilesCount()); + + assertThat(newDeletes.size()) + .as("Expected new delete file count does not match") + .isEqualTo(result.addedDeleteFilesCount()); + + assertThat(size(rewrittenDeletes)) + .as("Expected rewritten delete byte count does not match") + .isEqualTo(result.rewrittenBytesCount()); + + assertThat(size(newDeletes)) + .as("Expected new delete byte count does not match") + .isEqualTo(result.addedBytesCount()); + + assertThat(expectedGroups) + .as("Expected rewrite group count does not match") + .isEqualTo(result.rewriteResults().size()); + + assertThat(rewrittenDeletes.size()) + .as("Expected rewritten delete file count in all groups to match") + .isEqualTo( + result.rewriteResults().stream() + .mapToInt(FileGroupRewriteResult::rewrittenDeleteFilesCount) + .sum()); + + assertThat(newDeletes.size()) + .as("Expected added delete file count in all groups to match") + .isEqualTo( + result.rewriteResults().stream() + .mapToInt(FileGroupRewriteResult::addedDeleteFilesCount) + .sum()); + + assertThat(size(rewrittenDeletes)) + .as("Expected rewritten delete bytes in all groups to match") + .isEqualTo( + result.rewriteResults().stream() + .mapToLong(FileGroupRewriteResult::rewrittenBytesCount) + .sum()); + + assertThat(size(newDeletes)) + .as("Expected added delete bytes in all groups to match") + .isEqualTo( + result.rewriteResults().stream() + .mapToLong(FileGroupRewriteResult::addedBytesCount) + .sum()); + } + + private void checkSequenceNumbers( + Table table, List rewrittenDeletes, List addedDeletes) { + StructLikeMap> rewrittenFilesPerPartition = 
+ groupPerPartition(table, rewrittenDeletes); + StructLikeMap> addedFilesPerPartition = groupPerPartition(table, addedDeletes); + for (StructLike partition : rewrittenFilesPerPartition.keySet()) { + Long maxRewrittenSeq = + rewrittenFilesPerPartition.get(partition).stream() + .mapToLong(ContentFile::dataSequenceNumber) + .max() + .getAsLong(); + List addedPartitionFiles = addedFilesPerPartition.get(partition); + if (addedPartitionFiles != null) { + addedPartitionFiles.forEach( + d -> + assertThat(d.dataSequenceNumber()) + .as("Sequence number should be max of rewritten set") + .isEqualTo(maxRewrittenSeq)); + } + } + } + + private StructLikeMap> groupPerPartition( + Table table, List deleteFiles) { + StructLikeMap> result = + StructLikeMap.create(Partitioning.partitionType(table)); + for (DeleteFile deleteFile : deleteFiles) { + StructLike partition = deleteFile.partition(); + List partitionFiles = result.get(partition); + if (partitionFiles == null) { + partitionFiles = Lists.newArrayList(); + } + partitionFiles.add(deleteFile); + result.put(partition, partitionFiles); + } + return result; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestSnapshotTableAction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestSnapshotTableAction.java new file mode 100644 index 000000000000..d9c42a07b853 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestSnapshotTableAction.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.nio.file.Files; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.spark.CatalogTestBase; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSnapshotTableAction extends CatalogTestBase { + private static final String SOURCE_NAME = "spark_catalog.default.source"; + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s PURGE", SOURCE_NAME); + } + + @TestTemplate + public void testSnapshotWithParallelTasks() throws IOException { + String location = Files.createTempDirectory(temp, "junit").toFile().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + SOURCE_NAME, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", SOURCE_NAME); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", SOURCE_NAME); + + AtomicInteger snapshotThreadsIndex = new AtomicInteger(0); + SparkActions.get() + .snapshotTable(SOURCE_NAME) + .as(tableName) + .executeWith( + Executors.newFixedThreadPool( + 4, + runnable -> { + Thread thread = new Thread(runnable); + thread.setName("table-snapshot-" + snapshotThreadsIndex.getAndIncrement()); + thread.setDaemon(true); + return thread; + })) + .execute(); + assertThat(snapshotThreadsIndex.get()).isEqualTo(2); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestSparkFileRewriter.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestSparkFileRewriter.java new file mode 100644 index 000000000000..e223d2e16411 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/actions/TestSparkFileRewriter.java @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MockFileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.SizeBasedDataRewriter; +import org.apache.iceberg.actions.SizeBasedFileRewriter; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.types.Types.IntegerType; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StringType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +public class TestSparkFileRewriter extends TestBase { + + private static final TableIdentifier TABLE_IDENT = TableIdentifier.of("default", "tbl"); + private static final Schema SCHEMA = + new Schema( + NestedField.required(1, "id", IntegerType.get()), + NestedField.required(2, "dep", StringType.get())); + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA).identity("dep").build(); + private static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA).asc("id").build(); + + @AfterEach + public void removeTable() { + catalog.dropTable(TABLE_IDENT); + } + + @Test + public void testBinPackDataSelectFiles() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + SparkBinPackDataRewriter rewriter = new SparkBinPackDataRewriter(spark, table); + + checkDataFileSizeFiltering(rewriter); + checkDataFilesDeleteThreshold(rewriter); + checkDataFileGroupWithEnoughFiles(rewriter); + checkDataFileGroupWithEnoughData(rewriter); + checkDataFileGroupWithTooMuchData(rewriter); + } + + @Test + public void testSortDataSelectFiles() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + SparkSortDataRewriter rewriter = new SparkSortDataRewriter(spark, table, SORT_ORDER); + + checkDataFileSizeFiltering(rewriter); + checkDataFilesDeleteThreshold(rewriter); + checkDataFileGroupWithEnoughFiles(rewriter); + checkDataFileGroupWithEnoughData(rewriter); + checkDataFileGroupWithTooMuchData(rewriter); + } + + @Test + public void testZOrderDataSelectFiles() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + ImmutableList zOrderCols = ImmutableList.of("id"); + SparkZOrderDataRewriter rewriter = new SparkZOrderDataRewriter(spark, table, zOrderCols); + + checkDataFileSizeFiltering(rewriter); + checkDataFilesDeleteThreshold(rewriter); + checkDataFileGroupWithEnoughFiles(rewriter); + checkDataFileGroupWithEnoughData(rewriter); + checkDataFileGroupWithTooMuchData(rewriter); + } + + private void checkDataFileSizeFiltering(SizeBasedDataRewriter rewriter) { + FileScanTask tooSmallTask = new MockFileScanTask(100L); + FileScanTask optimal = new MockFileScanTask(450); + FileScanTask tooBigTask = new MockFileScanTask(1000L); + List tasks = ImmutableList.of(tooSmallTask, optimal, tooBigTask); + + Map options = + ImmutableMap.of( + SizeBasedDataRewriter.MIN_FILE_SIZE_BYTES, "250", + 
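+            // with min 250 / max 750, the 100-byte and 1000-byte mock tasks are the two files selected for rewrite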
SizeBasedDataRewriter.TARGET_FILE_SIZE_BYTES, "500", + SizeBasedDataRewriter.MAX_FILE_SIZE_BYTES, "750", + SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, String.valueOf(Integer.MAX_VALUE)); + rewriter.init(options); + + Iterable> groups = rewriter.planFileGroups(tasks); + assertThat(groups).as("Must have 1 group").hasSize(1); + List group = Iterables.getOnlyElement(groups); + assertThat(group).as("Must rewrite 2 files").hasSize(2); + } + + private void checkDataFilesDeleteThreshold(SizeBasedDataRewriter rewriter) { + FileScanTask tooManyDeletesTask = MockFileScanTask.mockTaskWithDeletes(1000L, 3); + FileScanTask optimalTask = MockFileScanTask.mockTaskWithDeletes(1000L, 1); + List tasks = ImmutableList.of(tooManyDeletesTask, optimalTask); + + Map options = + ImmutableMap.of( + SizeBasedDataRewriter.MIN_FILE_SIZE_BYTES, "1", + SizeBasedDataRewriter.TARGET_FILE_SIZE_BYTES, "2000", + SizeBasedDataRewriter.MAX_FILE_SIZE_BYTES, "5000", + SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, "2"); + rewriter.init(options); + + Iterable> groups = rewriter.planFileGroups(tasks); + assertThat(groups).as("Must have 1 group").hasSize(1); + List group = Iterables.getOnlyElement(groups); + assertThat(group).as("Must rewrite 1 file").hasSize(1); + } + + private void checkDataFileGroupWithEnoughFiles(SizeBasedDataRewriter rewriter) { + List tasks = + ImmutableList.of( + new MockFileScanTask(100L), + new MockFileScanTask(100L), + new MockFileScanTask(100L), + new MockFileScanTask(100L)); + + Map options = + ImmutableMap.of( + SizeBasedDataRewriter.MIN_INPUT_FILES, "3", + SizeBasedDataRewriter.MIN_FILE_SIZE_BYTES, "150", + SizeBasedDataRewriter.TARGET_FILE_SIZE_BYTES, "1000", + SizeBasedDataRewriter.MAX_FILE_SIZE_BYTES, "5000", + SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, String.valueOf(Integer.MAX_VALUE)); + rewriter.init(options); + + Iterable> groups = rewriter.planFileGroups(tasks); + assertThat(groups).as("Must have 1 group").hasSize(1); + List group = Iterables.getOnlyElement(groups); + assertThat(group).as("Must rewrite 4 files").hasSize(4); + } + + private void checkDataFileGroupWithEnoughData(SizeBasedDataRewriter rewriter) { + List tasks = + ImmutableList.of( + new MockFileScanTask(100L), new MockFileScanTask(100L), new MockFileScanTask(100L)); + + Map options = + ImmutableMap.of( + SizeBasedDataRewriter.MIN_INPUT_FILES, "5", + SizeBasedDataRewriter.MIN_FILE_SIZE_BYTES, "200", + SizeBasedDataRewriter.TARGET_FILE_SIZE_BYTES, "250", + SizeBasedDataRewriter.MAX_FILE_SIZE_BYTES, "500", + SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, String.valueOf(Integer.MAX_VALUE)); + rewriter.init(options); + + Iterable> groups = rewriter.planFileGroups(tasks); + assertThat(groups).as("Must have 1 group").hasSize(1); + List group = Iterables.getOnlyElement(groups); + assertThat(group).as("Must rewrite 3 files").hasSize(3); + } + + private void checkDataFileGroupWithTooMuchData(SizeBasedDataRewriter rewriter) { + List tasks = ImmutableList.of(new MockFileScanTask(2000L)); + + Map options = + ImmutableMap.of( + SizeBasedDataRewriter.MIN_INPUT_FILES, "5", + SizeBasedDataRewriter.MIN_FILE_SIZE_BYTES, "200", + SizeBasedDataRewriter.TARGET_FILE_SIZE_BYTES, "250", + SizeBasedDataRewriter.MAX_FILE_SIZE_BYTES, "500", + SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, String.valueOf(Integer.MAX_VALUE)); + rewriter.init(options); + + Iterable> groups = rewriter.planFileGroups(tasks); + assertThat(groups).as("Must have 1 group").hasSize(1); + List group = Iterables.getOnlyElement(groups); + assertThat(group).as("Must rewrite big 
file").hasSize(1); + } + + @Test + public void testInvalidConstructorUsagesSortData() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + + assertThatThrownBy(() -> new SparkSortDataRewriter(spark, table)) + .hasMessageContaining("Cannot sort data without a valid sort order") + .hasMessageContaining("is unsorted and no sort order is provided"); + + assertThatThrownBy(() -> new SparkSortDataRewriter(spark, table, null)) + .hasMessageContaining("Cannot sort data without a valid sort order") + .hasMessageContaining("the provided sort order is null or empty"); + + assertThatThrownBy(() -> new SparkSortDataRewriter(spark, table, SortOrder.unsorted())) + .hasMessageContaining("Cannot sort data without a valid sort order") + .hasMessageContaining("the provided sort order is null or empty"); + } + + @Test + public void testInvalidConstructorUsagesZOrderData() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA, SPEC); + + assertThatThrownBy(() -> new SparkZOrderDataRewriter(spark, table, null)) + .hasMessageContaining("Cannot ZOrder when no columns are specified"); + + assertThatThrownBy(() -> new SparkZOrderDataRewriter(spark, table, ImmutableList.of())) + .hasMessageContaining("Cannot ZOrder when no columns are specified"); + + assertThatThrownBy(() -> new SparkZOrderDataRewriter(spark, table, ImmutableList.of("dep"))) + .hasMessageContaining("Cannot ZOrder") + .hasMessageContaining("all columns provided were identity partition columns"); + + assertThatThrownBy(() -> new SparkZOrderDataRewriter(spark, table, ImmutableList.of("DeP"))) + .hasMessageContaining("Cannot ZOrder") + .hasMessageContaining("all columns provided were identity partition columns"); + } + + @Test + public void testBinPackDataValidOptions() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + SparkBinPackDataRewriter rewriter = new SparkBinPackDataRewriter(spark, table); + + assertThat(rewriter.validOptions()) + .as("Rewriter must report all supported options") + .isEqualTo( + ImmutableSet.of( + SparkBinPackDataRewriter.TARGET_FILE_SIZE_BYTES, + SparkBinPackDataRewriter.MIN_FILE_SIZE_BYTES, + SparkBinPackDataRewriter.MAX_FILE_SIZE_BYTES, + SparkBinPackDataRewriter.MIN_INPUT_FILES, + SparkBinPackDataRewriter.REWRITE_ALL, + SparkBinPackDataRewriter.MAX_FILE_GROUP_SIZE_BYTES, + SparkBinPackDataRewriter.DELETE_FILE_THRESHOLD)); + } + + @Test + public void testSortDataValidOptions() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + SparkSortDataRewriter rewriter = new SparkSortDataRewriter(spark, table, SORT_ORDER); + + assertThat(rewriter.validOptions()) + .as("Rewriter must report all supported options") + .isEqualTo( + ImmutableSet.of( + SparkSortDataRewriter.SHUFFLE_PARTITIONS_PER_FILE, + SparkSortDataRewriter.TARGET_FILE_SIZE_BYTES, + SparkSortDataRewriter.MIN_FILE_SIZE_BYTES, + SparkSortDataRewriter.MAX_FILE_SIZE_BYTES, + SparkSortDataRewriter.MIN_INPUT_FILES, + SparkSortDataRewriter.REWRITE_ALL, + SparkSortDataRewriter.MAX_FILE_GROUP_SIZE_BYTES, + SparkSortDataRewriter.DELETE_FILE_THRESHOLD, + SparkSortDataRewriter.COMPRESSION_FACTOR)); + } + + @Test + public void testZOrderDataValidOptions() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + ImmutableList zOrderCols = ImmutableList.of("id"); + SparkZOrderDataRewriter rewriter = new SparkZOrderDataRewriter(spark, table, zOrderCols); + + assertThat(rewriter.validOptions()) + .as("Rewriter must report all supported options") + .isEqualTo( + ImmutableSet.of( + SparkZOrderDataRewriter.SHUFFLE_PARTITIONS_PER_FILE, + 
SparkZOrderDataRewriter.TARGET_FILE_SIZE_BYTES, + SparkZOrderDataRewriter.MIN_FILE_SIZE_BYTES, + SparkZOrderDataRewriter.MAX_FILE_SIZE_BYTES, + SparkZOrderDataRewriter.MIN_INPUT_FILES, + SparkZOrderDataRewriter.REWRITE_ALL, + SparkZOrderDataRewriter.MAX_FILE_GROUP_SIZE_BYTES, + SparkZOrderDataRewriter.DELETE_FILE_THRESHOLD, + SparkZOrderDataRewriter.COMPRESSION_FACTOR, + SparkZOrderDataRewriter.MAX_OUTPUT_SIZE, + SparkZOrderDataRewriter.VAR_LENGTH_CONTRIBUTION)); + } + + @Test + public void testInvalidValuesForBinPackDataOptions() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + SparkBinPackDataRewriter rewriter = new SparkBinPackDataRewriter(spark, table); + + validateSizeBasedRewriterOptions(rewriter); + + Map invalidDeleteThresholdOptions = + ImmutableMap.of(SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, "-1"); + assertThatThrownBy(() -> rewriter.init(invalidDeleteThresholdOptions)) + .hasMessageContaining("'delete-file-threshold' is set to -1 but must be >= 0"); + } + + @Test + public void testInvalidValuesForSortDataOptions() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + SparkSortDataRewriter rewriter = new SparkSortDataRewriter(spark, table, SORT_ORDER); + + validateSizeBasedRewriterOptions(rewriter); + + Map invalidDeleteThresholdOptions = + ImmutableMap.of(SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, "-1"); + assertThatThrownBy(() -> rewriter.init(invalidDeleteThresholdOptions)) + .hasMessageContaining("'delete-file-threshold' is set to -1 but must be >= 0"); + + Map invalidCompressionFactorOptions = + ImmutableMap.of(SparkShufflingDataRewriter.COMPRESSION_FACTOR, "0"); + assertThatThrownBy(() -> rewriter.init(invalidCompressionFactorOptions)) + .hasMessageContaining("'compression-factor' is set to 0.0 but must be > 0"); + } + + @Test + public void testInvalidValuesForZOrderDataOptions() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + ImmutableList zOrderCols = ImmutableList.of("id"); + SparkZOrderDataRewriter rewriter = new SparkZOrderDataRewriter(spark, table, zOrderCols); + + validateSizeBasedRewriterOptions(rewriter); + + Map invalidDeleteThresholdOptions = + ImmutableMap.of(SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, "-1"); + assertThatThrownBy(() -> rewriter.init(invalidDeleteThresholdOptions)) + .hasMessageContaining("'delete-file-threshold' is set to -1 but must be >= 0"); + + Map invalidCompressionFactorOptions = + ImmutableMap.of(SparkShufflingDataRewriter.COMPRESSION_FACTOR, "0"); + assertThatThrownBy(() -> rewriter.init(invalidCompressionFactorOptions)) + .hasMessageContaining("'compression-factor' is set to 0.0 but must be > 0"); + + Map invalidMaxOutputOptions = + ImmutableMap.of(SparkZOrderDataRewriter.MAX_OUTPUT_SIZE, "0"); + assertThatThrownBy(() -> rewriter.init(invalidMaxOutputOptions)) + .hasMessageContaining("Cannot have the interleaved ZOrder value use less than 1 byte") + .hasMessageContaining("'max-output-size' was set to 0"); + + Map invalidVarLengthContributionOptions = + ImmutableMap.of(SparkZOrderDataRewriter.VAR_LENGTH_CONTRIBUTION, "0"); + assertThatThrownBy(() -> rewriter.init(invalidVarLengthContributionOptions)) + .hasMessageContaining("Cannot use less than 1 byte for variable length types with ZOrder") + .hasMessageContaining("'var-length-contribution' was set to 0"); + } + + private void validateSizeBasedRewriterOptions(SizeBasedFileRewriter rewriter) { + Map invalidTargetSizeOptions = + ImmutableMap.of(SizeBasedFileRewriter.TARGET_FILE_SIZE_BYTES, "0"); + assertThatThrownBy(() -> 
rewriter.init(invalidTargetSizeOptions)) + .hasMessageContaining("'target-file-size-bytes' is set to 0 but must be > 0"); + + Map invalidMinSizeOptions = + ImmutableMap.of(SizeBasedFileRewriter.MIN_FILE_SIZE_BYTES, "-1"); + assertThatThrownBy(() -> rewriter.init(invalidMinSizeOptions)) + .hasMessageContaining("'min-file-size-bytes' is set to -1 but must be >= 0"); + + Map invalidTargetMinSizeOptions = + ImmutableMap.of( + SizeBasedFileRewriter.TARGET_FILE_SIZE_BYTES, "3", + SizeBasedFileRewriter.MIN_FILE_SIZE_BYTES, "5"); + assertThatThrownBy(() -> rewriter.init(invalidTargetMinSizeOptions)) + .hasMessageContaining("'target-file-size-bytes' (3) must be > 'min-file-size-bytes' (5)") + .hasMessageContaining("all new files will be smaller than the min threshold"); + + Map invalidTargetMaxSizeOptions = + ImmutableMap.of( + SizeBasedFileRewriter.TARGET_FILE_SIZE_BYTES, "5", + SizeBasedFileRewriter.MAX_FILE_SIZE_BYTES, "3"); + assertThatThrownBy(() -> rewriter.init(invalidTargetMaxSizeOptions)) + .hasMessageContaining("'target-file-size-bytes' (5) must be < 'max-file-size-bytes' (3)") + .hasMessageContaining("all new files will be larger than the max threshold"); + + Map invalidMinInputFilesOptions = + ImmutableMap.of(SizeBasedFileRewriter.MIN_INPUT_FILES, "0"); + assertThatThrownBy(() -> rewriter.init(invalidMinInputFilesOptions)) + .hasMessageContaining("'min-input-files' is set to 0 but must be > 0"); + + Map invalidMaxFileGroupSizeOptions = + ImmutableMap.of(SizeBasedFileRewriter.MAX_FILE_GROUP_SIZE_BYTES, "0"); + assertThatThrownBy(() -> rewriter.init(invalidMaxFileGroupSizeOptions)) + .hasMessageContaining("'max-file-group-size-bytes' is set to 0 but must be > 0"); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java new file mode 100644 index 000000000000..8f90a51a6e30 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.ListType; +import org.apache.iceberg.types.Types.LongType; +import org.apache.iceberg.types.Types.MapType; +import org.apache.iceberg.types.Types.StructType; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public abstract class AvroDataTest { + + protected abstract void writeAndValidate(Schema schema) throws IOException; + + protected static final StructType SUPPORTED_PRIMITIVES = + StructType.of( + required(100, "id", LongType.get()), + optional(101, "data", Types.StringType.get()), + required(102, "b", Types.BooleanType.get()), + optional(103, "i", Types.IntegerType.get()), + required(104, "l", LongType.get()), + optional(105, "f", Types.FloatType.get()), + required(106, "d", Types.DoubleType.get()), + optional(107, "date", Types.DateType.get()), + required(108, "ts", Types.TimestampType.withZone()), + required(110, "s", Types.StringType.get()), + required(111, "uuid", Types.UUIDType.get()), + required(112, "fixed", Types.FixedType.ofLength(7)), + optional(113, "bytes", Types.BinaryType.get()), + required(114, "dec_9_0", Types.DecimalType.of(9, 0)), // int encoded + required(115, "dec_11_2", Types.DecimalType.of(11, 2)), // long encoded + required(116, "dec_20_5", Types.DecimalType.of(20, 5)), // requires padding + required(117, "dec_38_10", Types.DecimalType.of(38, 10)) // Spark's maximum precision + ); + + @TempDir protected Path temp; + + @Test + public void testSimpleStruct() throws IOException { + writeAndValidate(TypeUtil.assignIncreasingFreshIds(new Schema(SUPPORTED_PRIMITIVES.fields()))); + } + + @Test + public void testStructWithRequiredFields() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds( + new Schema( + Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asRequired)))); + } + + @Test + public void testStructWithOptionalFields() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds( + new Schema( + Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asOptional)))); + } + + @Test + public void testNestedStruct() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds(new Schema(required(1, "struct", SUPPORTED_PRIMITIVES)))); + } + + @Test + public void testArray() throws IOException { + Schema schema = + new Schema( + required(0, "id", LongType.get()), + optional(1, "data", ListType.ofOptional(2, Types.StringType.get()))); + + writeAndValidate(schema); + } + + @Test + public void testArrayOfStructs() throws IOException { + Schema schema = + TypeUtil.assignIncreasingFreshIds( + new Schema( + required(0, "id", LongType.get()), + optional(1, "data", ListType.ofOptional(2, SUPPORTED_PRIMITIVES)))); + + writeAndValidate(schema); + } + + @Test + public void testMap() throws IOException { + Schema schema = + new Schema( + required(0, "id", LongType.get()), + optional( + 1, + "data", + 
MapType.ofOptional(2, 3, Types.StringType.get(), Types.StringType.get()))); + + writeAndValidate(schema); + } + + @Test + public void testNumericMapKey() throws IOException { + Schema schema = + new Schema( + required(0, "id", LongType.get()), + optional( + 1, "data", MapType.ofOptional(2, 3, Types.LongType.get(), Types.StringType.get()))); + + writeAndValidate(schema); + } + + @Test + public void testComplexMapKey() throws IOException { + Schema schema = + new Schema( + required(0, "id", LongType.get()), + optional( + 1, + "data", + MapType.ofOptional( + 2, + 3, + Types.StructType.of( + required(4, "i", Types.IntegerType.get()), + optional(5, "s", Types.StringType.get())), + Types.StringType.get()))); + + writeAndValidate(schema); + } + + @Test + public void testMapOfStructs() throws IOException { + Schema schema = + TypeUtil.assignIncreasingFreshIds( + new Schema( + required(0, "id", LongType.get()), + optional( + 1, + "data", + MapType.ofOptional(2, 3, Types.StringType.get(), SUPPORTED_PRIMITIVES)))); + + writeAndValidate(schema); + } + + @Test + public void testMixedTypes() throws IOException { + StructType structType = + StructType.of( + required(0, "id", LongType.get()), + optional( + 1, + "list_of_maps", + ListType.ofOptional( + 2, MapType.ofOptional(3, 4, Types.StringType.get(), SUPPORTED_PRIMITIVES))), + optional( + 5, + "map_of_lists", + MapType.ofOptional( + 6, 7, Types.StringType.get(), ListType.ofOptional(8, SUPPORTED_PRIMITIVES))), + required( + 9, + "list_of_lists", + ListType.ofOptional(10, ListType.ofOptional(11, SUPPORTED_PRIMITIVES))), + required( + 12, + "map_of_maps", + MapType.ofOptional( + 13, + 14, + Types.StringType.get(), + MapType.ofOptional(15, 16, Types.StringType.get(), SUPPORTED_PRIMITIVES))), + required( + 17, + "list_of_struct_of_nested_types", + ListType.ofOptional( + 19, + StructType.of( + Types.NestedField.required( + 20, + "m1", + MapType.ofOptional( + 21, 22, Types.StringType.get(), SUPPORTED_PRIMITIVES)), + Types.NestedField.optional( + 23, "l1", ListType.ofRequired(24, SUPPORTED_PRIMITIVES)), + Types.NestedField.required( + 25, "l2", ListType.ofRequired(26, SUPPORTED_PRIMITIVES)), + Types.NestedField.optional( + 27, + "m2", + MapType.ofOptional( + 28, 29, Types.StringType.get(), SUPPORTED_PRIMITIVES)))))); + + Schema schema = + new Schema( + TypeUtil.assignFreshIds(structType, new AtomicInteger(0)::incrementAndGet) + .asStructType() + .fields()); + + writeAndValidate(schema); + } + + @Test + public void testTimestampWithoutZone() throws IOException { + Schema schema = + TypeUtil.assignIncreasingFreshIds( + new Schema( + required(0, "id", LongType.get()), + optional(1, "ts_without_zone", Types.TimestampType.withoutZone()))); + + writeAndValidate(schema); + } + + protected void withSQLConf(Map conf, Action action) throws IOException { + SQLConf sqlConf = SQLConf.get(); + + Map currentConfValues = Maps.newHashMap(); + conf.keySet() + .forEach( + confKey -> { + if (sqlConf.contains(confKey)) { + String currentConfValue = sqlConf.getConfString(confKey); + currentConfValues.put(confKey, currentConfValue); + } + }); + + conf.forEach( + (confKey, confValue) -> { + if (SQLConf.isStaticConfigKey(confKey)) { + throw new RuntimeException("Cannot modify the value of a static config: " + confKey); + } + sqlConf.setConfString(confKey, confValue); + }); + + try { + action.invoke(); + } finally { + conf.forEach( + (confKey, confValue) -> { + if (currentConfValues.containsKey(confKey)) { + sqlConf.setConfString(confKey, currentConfValues.get(confKey)); + } else { + 
sqlConf.unsetConf(confKey); + } + }); + } + } + + @FunctionalInterface + protected interface Action { + void invoke() throws IOException; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/GenericsHelpers.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/GenericsHelpers.java new file mode 100644 index 000000000000..501b46878bd2 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/GenericsHelpers.java @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.spark.SparkSchemaUtil.convert; +import static org.assertj.core.api.Assertions.assertThat; +import static scala.collection.JavaConverters.mapAsJavaMapConverter; +import static scala.collection.JavaConverters.seqAsJavaListConverter; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.ZoneOffset; +import java.time.temporal.ChronoUnit; +import java.util.Collection; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; +import scala.collection.Seq; + +public class GenericsHelpers { + private GenericsHelpers() {} + + private static final OffsetDateTime EPOCH = Instant.ofEpochMilli(0L).atOffset(ZoneOffset.UTC); + private static final LocalDate EPOCH_DAY = EPOCH.toLocalDate(); + + public static void assertEqualsSafe(Types.StructType struct, Record expected, Row actual) { + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + + Object expectedValue = expected.get(i); + Object actualValue = actual.get(i); + + assertEqualsSafe(fieldType, expectedValue, actualValue); + } + } + + private static void assertEqualsSafe( + Types.ListType list, Collection expected, List actual) { + Type elementType = list.elementType(); + List expectedElements = Lists.newArrayList(expected); + for (int i = 0; i < expectedElements.size(); i += 1) { + Object expectedValue = expectedElements.get(i); + Object actualValue = actual.get(i); + + assertEqualsSafe(elementType, expectedValue, 
actualValue); + } + } + + private static void assertEqualsSafe(Types.MapType map, Map expected, Map actual) { + Type keyType = map.keyType(); + Type valueType = map.valueType(); + assertThat(actual.keySet()) + .as("Should have the same number of keys") + .hasSameSizeAs(expected.keySet()); + + for (Object expectedKey : expected.keySet()) { + Object matchingKey = null; + for (Object actualKey : actual.keySet()) { + try { + assertEqualsSafe(keyType, expectedKey, actualKey); + matchingKey = actualKey; + break; + } catch (AssertionError e) { + // failed + } + } + + assertThat(matchingKey).as("Should have a matching key").isNotNull(); + assertEqualsSafe(valueType, expected.get(expectedKey), actual.get(matchingKey)); + } + } + + @SuppressWarnings("unchecked") + private static void assertEqualsSafe(Type type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + switch (type.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + assertThat(actual).as("Primitive value should be equal to expected").isEqualTo(expected); + break; + case DATE: + assertThat(expected).as("Should expect a LocalDate").isInstanceOf(LocalDate.class); + assertThat(actual).as("Should be a Date").isInstanceOf(Date.class); + assertThat(actual.toString()) + .as("ISO-8601 date should be equal") + .isEqualTo(String.valueOf(expected)); + break; + case TIMESTAMP: + Types.TimestampType timestampType = (Types.TimestampType) type; + if (timestampType.shouldAdjustToUTC()) { + // Timestamptz + assertThat(actual).as("Should be a Timestamp").isInstanceOf(Timestamp.class); + Timestamp ts = (Timestamp) actual; + // milliseconds from nanos has already been added by getTime + OffsetDateTime actualTs = + EPOCH.plusNanos((ts.getTime() * 1_000_000) + (ts.getNanos() % 1_000_000)); + + assertThat(expected) + .as("Should expect an OffsetDateTime") + .isInstanceOf(OffsetDateTime.class); + + assertThat(actualTs).as("Timestamp should be equal").isEqualTo(expected); + } else { + // Timestamp + assertThat(actual).as("Should be a LocalDateTime").isInstanceOf(LocalDateTime.class); + + assertThat(expected) + .as("Should expect an LocalDateTime") + .isInstanceOf(LocalDateTime.class); + + assertThat(actual).as("Timestamp should be equal").isEqualTo(expected); + } + break; + case STRING: + assertThat(actual).as("Should be a String").isInstanceOf(String.class); + assertThat(actual.toString()) + .as("Strings should be equal") + .isEqualTo(String.valueOf(expected)); + break; + case UUID: + assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class); + assertThat(actual).as("Should be a String").isInstanceOf(String.class); + assertThat(actual.toString()) + .as("UUID string representation should match") + .isEqualTo(String.valueOf(expected)); + break; + case FIXED: + assertThat(expected).as("Should expect a byte[]").isInstanceOf(byte[].class); + assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + assertThat(actual).as("Bytes should match").isEqualTo(expected); + break; + case BINARY: + assertThat(expected).as("Should expect a ByteBuffer").isInstanceOf(ByteBuffer.class); + assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + assertThat(actual).as("Bytes should match").isEqualTo(((ByteBuffer) expected).array()); + break; + case DECIMAL: + assertThat(expected).as("Should expect a BigDecimal").isInstanceOf(BigDecimal.class); + assertThat(actual).as("Should be a BigDecimal").isInstanceOf(BigDecimal.class); + assertThat(actual).as("BigDecimals 
should be equal").isEqualTo(expected); + break; + case STRUCT: + assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); + assertThat(actual).as("Should be a Row").isInstanceOf(Row.class); + assertEqualsSafe(type.asNestedType().asStructType(), (Record) expected, (Row) actual); + break; + case LIST: + assertThat(expected).as("Should expect a Collection").isInstanceOf(Collection.class); + assertThat(actual).as("Should be a Seq").isInstanceOf(Seq.class); + List asList = seqAsJavaListConverter((Seq) actual).asJava(); + assertEqualsSafe(type.asNestedType().asListType(), (Collection) expected, asList); + break; + case MAP: + assertThat(expected).as("Should expect a Collection").isInstanceOf(Map.class); + assertThat(actual).as("Should be a Map").isInstanceOf(scala.collection.Map.class); + Map asMap = + mapAsJavaMapConverter((scala.collection.Map) actual).asJava(); + assertEqualsSafe(type.asNestedType().asMapType(), (Map) expected, asMap); + break; + case TIME: + default: + throw new IllegalArgumentException("Not a supported type: " + type); + } + } + + public static void assertEqualsUnsafe( + Types.StructType struct, Record expected, InternalRow actual) { + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + + Object expectedValue = expected.get(i); + Object actualValue = actual.get(i, convert(fieldType)); + + assertEqualsUnsafe(fieldType, expectedValue, actualValue); + } + } + + private static void assertEqualsUnsafe( + Types.ListType list, Collection expected, ArrayData actual) { + Type elementType = list.elementType(); + List expectedElements = Lists.newArrayList(expected); + for (int i = 0; i < expectedElements.size(); i += 1) { + Object expectedValue = expectedElements.get(i); + Object actualValue = actual.get(i, convert(elementType)); + + assertEqualsUnsafe(elementType, expectedValue, actualValue); + } + } + + private static void assertEqualsUnsafe(Types.MapType map, Map expected, MapData actual) { + Type keyType = map.keyType(); + Type valueType = map.valueType(); + + List> expectedElements = Lists.newArrayList(expected.entrySet()); + ArrayData actualKeys = actual.keyArray(); + ArrayData actualValues = actual.valueArray(); + + for (int i = 0; i < expectedElements.size(); i += 1) { + Map.Entry expectedPair = expectedElements.get(i); + Object actualKey = actualKeys.get(i, convert(keyType)); + Object actualValue = actualValues.get(i, convert(keyType)); + + assertEqualsUnsafe(keyType, expectedPair.getKey(), actualKey); + assertEqualsUnsafe(valueType, expectedPair.getValue(), actualValue); + } + } + + private static void assertEqualsUnsafe(Type type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + switch (type.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + assertThat(actual).as("Primitive value should be equal to expected").isEqualTo(expected); + break; + case DATE: + assertThat(expected).as("Should expect a LocalDate").isInstanceOf(LocalDate.class); + int expectedDays = (int) ChronoUnit.DAYS.between(EPOCH_DAY, (LocalDate) expected); + assertThat(actual) + .as("Primitive value should be equal to expected") + .isEqualTo(expectedDays); + break; + case TIMESTAMP: + Types.TimestampType timestampType = (Types.TimestampType) type; + if (timestampType.shouldAdjustToUTC()) { + assertThat(expected) + .as("Should expect an OffsetDateTime") + .isInstanceOf(OffsetDateTime.class); + long expectedMicros = 
ChronoUnit.MICROS.between(EPOCH, (OffsetDateTime) expected); + assertThat(actual) + .as("Primitive value should be equal to expected") + .isEqualTo(expectedMicros); + } else { + assertThat(expected) + .as("Should expect an LocalDateTime") + .isInstanceOf(LocalDateTime.class); + long expectedMicros = + ChronoUnit.MICROS.between(EPOCH, ((LocalDateTime) expected).atZone(ZoneId.of("UTC"))); + assertThat(actual) + .as("Primitive value should be equal to expected") + .isEqualTo(expectedMicros); + } + break; + case STRING: + assertThat(actual).as("Should be a UTF8String").isInstanceOf(UTF8String.class); + assertThat(actual.toString()) + .as("Strings should be equal") + .isEqualTo(String.valueOf(expected)); + break; + case UUID: + assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class); + assertThat(actual).as("Should be a UTF8String").isInstanceOf(UTF8String.class); + assertThat(actual.toString()) + .as("UUID string representation should match") + .isEqualTo(String.valueOf(expected)); + break; + case FIXED: + assertThat(expected).as("Should expect a byte[]").isInstanceOf(byte[].class); + assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + assertThat(actual).as("Bytes should match").isEqualTo(expected); + break; + case BINARY: + assertThat(expected).as("Should expect a ByteBuffer").isInstanceOf(ByteBuffer.class); + assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + assertThat(actual).as("Bytes should match").isEqualTo(((ByteBuffer) expected).array()); + break; + case DECIMAL: + assertThat(expected).as("Should expect a BigDecimal").isInstanceOf(BigDecimal.class); + assertThat(actual).as("Should be a Decimal").isInstanceOf(Decimal.class); + assertThat(((Decimal) actual).toJavaBigDecimal()) + .as("BigDecimals should be equal") + .isEqualTo(expected); + break; + case STRUCT: + assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); + assertThat(actual).as("Should be an InternalRow").isInstanceOf(InternalRow.class); + assertEqualsUnsafe( + type.asNestedType().asStructType(), (Record) expected, (InternalRow) actual); + break; + case LIST: + assertThat(expected).as("Should expect a Collection").isInstanceOf(Collection.class); + assertThat(actual).as("Should be an ArrayData").isInstanceOf(ArrayData.class); + assertEqualsUnsafe( + type.asNestedType().asListType(), (Collection) expected, (ArrayData) actual); + break; + case MAP: + assertThat(expected).as("Should expect a Map").isInstanceOf(Map.class); + assertThat(actual).as("Should be an ArrayBasedMapData").isInstanceOf(MapData.class); + assertEqualsUnsafe(type.asNestedType().asMapType(), (Map) expected, (MapData) actual); + break; + case TIME: + default: + throw new IllegalArgumentException("Not a supported type: " + type); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/ParameterizedAvroDataTest.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/ParameterizedAvroDataTest.java new file mode 100644 index 000000000000..85effe7d39a7 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/ParameterizedAvroDataTest.java @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.ListType; +import org.apache.iceberg.types.Types.LongType; +import org.apache.iceberg.types.Types.MapType; +import org.apache.iceberg.types.Types.StructType; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.io.TempDir; + +/** + * Copy of {@link AvroDataTest} that marks tests with @{@link org.junit.jupiter.api.TestTemplate} + * instead of @{@link Test} to make the tests work in a parameterized environment. + */ +public abstract class ParameterizedAvroDataTest { + + protected abstract void writeAndValidate(Schema schema) throws IOException; + + protected static final StructType SUPPORTED_PRIMITIVES = + StructType.of( + required(100, "id", LongType.get()), + optional(101, "data", Types.StringType.get()), + required(102, "b", Types.BooleanType.get()), + optional(103, "i", Types.IntegerType.get()), + required(104, "l", LongType.get()), + optional(105, "f", Types.FloatType.get()), + required(106, "d", Types.DoubleType.get()), + optional(107, "date", Types.DateType.get()), + required(108, "ts", Types.TimestampType.withZone()), + required(110, "s", Types.StringType.get()), + required(111, "uuid", Types.UUIDType.get()), + required(112, "fixed", Types.FixedType.ofLength(7)), + optional(113, "bytes", Types.BinaryType.get()), + required(114, "dec_9_0", Types.DecimalType.of(9, 0)), // int encoded + required(115, "dec_11_2", Types.DecimalType.of(11, 2)), // long encoded + required(116, "dec_20_5", Types.DecimalType.of(20, 5)), // requires padding + required(117, "dec_38_10", Types.DecimalType.of(38, 10)) // Spark's maximum precision + ); + + @TempDir protected Path temp; + + @TestTemplate + public void testSimpleStruct() throws IOException { + writeAndValidate(TypeUtil.assignIncreasingFreshIds(new Schema(SUPPORTED_PRIMITIVES.fields()))); + } + + @TestTemplate + public void testStructWithRequiredFields() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds( + new Schema( + Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asRequired)))); + } + + @TestTemplate + public void testStructWithOptionalFields() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds( + new Schema( + Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asOptional)))); + } 
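As the class Javadoc above notes, these cases use @TestTemplate rather than @Test so that a parameterized extension can decide how many times each method runs. For readers who have not used that JUnit 5 mechanism, here is a minimal, self-contained sketch of how a TestTemplateInvocationContextProvider drives @TestTemplate methods; FormatsProvider and ExampleParameterizedTest are illustrative names only and are not part of this patch.

```java
import java.util.stream.Stream;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.TestTemplateInvocationContext;
import org.junit.jupiter.api.extension.TestTemplateInvocationContextProvider;

// Illustrative provider: invokes each @TestTemplate method once per file format name.
class FormatsProvider implements TestTemplateInvocationContextProvider {
  @Override
  public boolean supportsTestTemplate(ExtensionContext context) {
    return true;
  }

  @Override
  public Stream<TestTemplateInvocationContext> provideTestTemplateInvocationContexts(
      ExtensionContext context) {
    return Stream.of("parquet", "avro", "orc")
        .map(
            format ->
                new TestTemplateInvocationContext() {
                  @Override
                  public String getDisplayName(int invocationIndex) {
                    return "format = " + format;
                  }
                });
  }
}

// A @TestTemplate method is not executable on its own; the registered provider
// decides how many times (and under which display names) it runs.
@ExtendWith(FormatsProvider.class)
class ExampleParameterizedTest {
  @TestTemplate
  void runsOncePerFormat() {}
}
```

A plain @Test method would run exactly once and ignore the parameter dimension supplied by the extension, which is why this copy of AvroDataTest keeps the test bodies identical and changes only the annotation.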
+ + @TestTemplate + public void testNestedStruct() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds(new Schema(required(1, "struct", SUPPORTED_PRIMITIVES)))); + } + + @TestTemplate + public void testArray() throws IOException { + Schema schema = + new Schema( + required(0, "id", LongType.get()), + optional(1, "data", ListType.ofOptional(2, Types.StringType.get()))); + + writeAndValidate(schema); + } + + @TestTemplate + public void testArrayOfStructs() throws IOException { + Schema schema = + TypeUtil.assignIncreasingFreshIds( + new Schema( + required(0, "id", LongType.get()), + optional(1, "data", ListType.ofOptional(2, SUPPORTED_PRIMITIVES)))); + + writeAndValidate(schema); + } + + @TestTemplate + public void testMap() throws IOException { + Schema schema = + new Schema( + required(0, "id", LongType.get()), + optional( + 1, + "data", + MapType.ofOptional(2, 3, Types.StringType.get(), Types.StringType.get()))); + + writeAndValidate(schema); + } + + @TestTemplate + public void testNumericMapKey() throws IOException { + Schema schema = + new Schema( + required(0, "id", LongType.get()), + optional(1, "data", MapType.ofOptional(2, 3, LongType.get(), Types.StringType.get()))); + + writeAndValidate(schema); + } + + @TestTemplate + public void testComplexMapKey() throws IOException { + Schema schema = + new Schema( + required(0, "id", LongType.get()), + optional( + 1, + "data", + MapType.ofOptional( + 2, + 3, + StructType.of( + required(4, "i", Types.IntegerType.get()), + optional(5, "s", Types.StringType.get())), + Types.StringType.get()))); + + writeAndValidate(schema); + } + + @TestTemplate + public void testMapOfStructs() throws IOException { + Schema schema = + TypeUtil.assignIncreasingFreshIds( + new Schema( + required(0, "id", LongType.get()), + optional( + 1, + "data", + MapType.ofOptional(2, 3, Types.StringType.get(), SUPPORTED_PRIMITIVES)))); + + writeAndValidate(schema); + } + + @TestTemplate + public void testMixedTypes() throws IOException { + StructType structType = + StructType.of( + required(0, "id", LongType.get()), + optional( + 1, + "list_of_maps", + ListType.ofOptional( + 2, MapType.ofOptional(3, 4, Types.StringType.get(), SUPPORTED_PRIMITIVES))), + optional( + 5, + "map_of_lists", + MapType.ofOptional( + 6, 7, Types.StringType.get(), ListType.ofOptional(8, SUPPORTED_PRIMITIVES))), + required( + 9, + "list_of_lists", + ListType.ofOptional(10, ListType.ofOptional(11, SUPPORTED_PRIMITIVES))), + required( + 12, + "map_of_maps", + MapType.ofOptional( + 13, + 14, + Types.StringType.get(), + MapType.ofOptional(15, 16, Types.StringType.get(), SUPPORTED_PRIMITIVES))), + required( + 17, + "list_of_struct_of_nested_types", + ListType.ofOptional( + 19, + StructType.of( + Types.NestedField.required( + 20, + "m1", + MapType.ofOptional( + 21, 22, Types.StringType.get(), SUPPORTED_PRIMITIVES)), + Types.NestedField.optional( + 23, "l1", ListType.ofRequired(24, SUPPORTED_PRIMITIVES)), + Types.NestedField.required( + 25, "l2", ListType.ofRequired(26, SUPPORTED_PRIMITIVES)), + Types.NestedField.optional( + 27, + "m2", + MapType.ofOptional( + 28, 29, Types.StringType.get(), SUPPORTED_PRIMITIVES)))))); + + Schema schema = + new Schema( + TypeUtil.assignFreshIds(structType, new AtomicInteger(0)::incrementAndGet) + .asStructType() + .fields()); + + writeAndValidate(schema); + } + + @TestTemplate + public void testTimestampWithoutZone() throws IOException { + Schema schema = + TypeUtil.assignIncreasingFreshIds( + new Schema( + required(0, "id", LongType.get()), + 
optional(1, "ts_without_zone", Types.TimestampType.withoutZone()))); + + writeAndValidate(schema); + } + + protected void withSQLConf(Map conf, Action action) throws IOException { + SQLConf sqlConf = SQLConf.get(); + + Map currentConfValues = Maps.newHashMap(); + conf.keySet() + .forEach( + confKey -> { + if (sqlConf.contains(confKey)) { + String currentConfValue = sqlConf.getConfString(confKey); + currentConfValues.put(confKey, currentConfValue); + } + }); + + conf.forEach( + (confKey, confValue) -> { + if (SQLConf.isStaticConfigKey(confKey)) { + throw new RuntimeException("Cannot modify the value of a static config: " + confKey); + } + sqlConf.setConfString(confKey, confValue); + }); + + try { + action.invoke(); + } finally { + conf.forEach( + (confKey, confValue) -> { + if (currentConfValues.containsKey(confKey)) { + sqlConf.setConfString(confKey, currentConfValues.get(confKey)); + } else { + sqlConf.unsetConf(confKey); + } + }); + } + } + + @FunctionalInterface + protected interface Action { + void invoke() throws IOException; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java new file mode 100644 index 000000000000..360b9ff20ec0 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.function.Supplier; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.RandomUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +public class RandomData { + + // Default percentage of number of values that are null for optional fields + public static final float DEFAULT_NULL_PERCENTAGE = 0.05f; + + private RandomData() {} + + public static List generateList(Schema schema, int numRecords, long seed) { + RandomDataGenerator generator = new RandomDataGenerator(schema, seed, DEFAULT_NULL_PERCENTAGE); + List records = Lists.newArrayListWithExpectedSize(numRecords); + for (int i = 0; i < numRecords; i += 1) { + records.add((Record) TypeUtil.visit(schema, generator)); + } + + return records; + } + + public static Iterable generateSpark(Schema schema, int numRecords, long seed) { + return () -> + new Iterator() { + private final SparkRandomDataGenerator generator = new SparkRandomDataGenerator(seed); + private int count = 0; + + @Override + public boolean hasNext() { + return count < numRecords; + } + + @Override + public InternalRow next() { + if (count >= numRecords) { + throw new NoSuchElementException(); + } + count += 1; + return (InternalRow) TypeUtil.visit(schema, generator); + } + }; + } + + public static Iterable generate(Schema schema, int numRecords, long seed) { + return newIterable( + () -> new RandomDataGenerator(schema, seed, DEFAULT_NULL_PERCENTAGE), schema, numRecords); + } + + public static Iterable generate( + Schema schema, int numRecords, long seed, float nullPercentage) { + return newIterable( + () -> new RandomDataGenerator(schema, seed, nullPercentage), schema, numRecords); + } + + public static Iterable generateFallbackData( + Schema schema, int numRecords, long seed, long numDictRecords) { + return newIterable( + () -> new FallbackDataGenerator(schema, seed, numDictRecords), schema, numRecords); + } + + public static Iterable generateDictionaryEncodableData( + Schema schema, int numRecords, long seed, float nullPercentage) { + return newIterable( + () -> new DictionaryEncodedDataGenerator(schema, seed, nullPercentage), schema, numRecords); + } + + private static Iterable newIterable( + Supplier newGenerator, Schema schema, int numRecords) { + return () -> + new Iterator() { + private int count = 0; + private final RandomDataGenerator generator = newGenerator.get(); + + @Override + public boolean hasNext() { + return 
count < numRecords; + } + + @Override + public Record next() { + if (count >= numRecords) { + throw new NoSuchElementException(); + } + count += 1; + return (Record) TypeUtil.visit(schema, generator); + } + }; + } + + private static class RandomDataGenerator extends TypeUtil.CustomOrderSchemaVisitor { + private final Map typeToSchema; + private final Random random; + // Percentage of number of values that are null for optional fields + private final float nullPercentage; + + private RandomDataGenerator(Schema schema, long seed, float nullPercentage) { + Preconditions.checkArgument( + 0.0f <= nullPercentage && nullPercentage <= 1.0f, + "Percentage needs to be in the range (0.0, 1.0)"); + this.nullPercentage = nullPercentage; + this.typeToSchema = AvroSchemaUtil.convertTypes(schema.asStruct(), "test"); + this.random = new Random(seed); + } + + @Override + public Record schema(Schema schema, Supplier structResult) { + return (Record) structResult.get(); + } + + @Override + public Record struct(Types.StructType struct, Iterable fieldResults) { + Record rec = new Record(typeToSchema.get(struct)); + + List values = Lists.newArrayList(fieldResults); + for (int i = 0; i < values.size(); i += 1) { + rec.put(i, values.get(i)); + } + + return rec; + } + + @Override + public Object field(Types.NestedField field, Supplier fieldResult) { + if (field.isOptional() && isNull()) { + return null; + } + return fieldResult.get(); + } + + private boolean isNull() { + return random.nextFloat() < nullPercentage; + } + + @Override + public Object list(Types.ListType list, Supplier elementResult) { + int numElements = random.nextInt(20); + + List result = Lists.newArrayListWithExpectedSize(numElements); + for (int i = 0; i < numElements; i += 1) { + if (list.isElementOptional() && isNull()) { + result.add(null); + } else { + result.add(elementResult.get()); + } + } + + return result; + } + + @Override + public Object map(Types.MapType map, Supplier keyResult, Supplier valueResult) { + int numEntries = random.nextInt(20); + + Map result = Maps.newLinkedHashMap(); + Set keySet = Sets.newHashSet(); + for (int i = 0; i < numEntries; i += 1) { + Object key = keyResult.get(); + // ensure no collisions + while (keySet.contains(key)) { + key = keyResult.get(); + } + + keySet.add(key); + + if (map.isValueOptional() && isNull()) { + result.put(key, null); + } else { + result.put(key, valueResult.get()); + } + } + + return result; + } + + @Override + public Object primitive(Type.PrimitiveType primitive) { + Object result = randomValue(primitive, random); + // For the primitives that Avro needs a different type than Spark, fix + // them here. 
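+      // (FIXED values are wrapped in GenericData.Fixed with their Avro schema, BINARY byte[]
+      // values are wrapped in a ByteBuffer, and UUID values are derived from the random bytes
+      // via UUID.nameUUIDFromBytes, so the generated data uses the types the Avro writers expect.)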
+ switch (primitive.typeId()) { + case FIXED: + return new GenericData.Fixed(typeToSchema.get(primitive), (byte[]) result); + case BINARY: + return ByteBuffer.wrap((byte[]) result); + case UUID: + return UUID.nameUUIDFromBytes((byte[]) result); + default: + return result; + } + } + + protected Object randomValue(Type.PrimitiveType primitive, Random rand) { + return RandomUtil.generatePrimitive(primitive, random); + } + } + + private static class SparkRandomDataGenerator extends TypeUtil.CustomOrderSchemaVisitor { + private final Random random; + + private SparkRandomDataGenerator(long seed) { + this.random = new Random(seed); + } + + @Override + public InternalRow schema(Schema schema, Supplier structResult) { + return (InternalRow) structResult.get(); + } + + @Override + public InternalRow struct(Types.StructType struct, Iterable fieldResults) { + List values = Lists.newArrayList(fieldResults); + GenericInternalRow row = new GenericInternalRow(values.size()); + for (int i = 0; i < values.size(); i += 1) { + row.update(i, values.get(i)); + } + + return row; + } + + @Override + public Object field(Types.NestedField field, Supplier fieldResult) { + // return null 5% of the time when the value is optional + if (field.isOptional() && random.nextInt(20) == 1) { + return null; + } + return fieldResult.get(); + } + + @Override + public GenericArrayData list(Types.ListType list, Supplier elementResult) { + int numElements = random.nextInt(20); + Object[] arr = new Object[numElements]; + GenericArrayData result = new GenericArrayData(arr); + + for (int i = 0; i < numElements; i += 1) { + // return null 5% of the time when the value is optional + if (list.isElementOptional() && random.nextInt(20) == 1) { + arr[i] = null; + } else { + arr[i] = elementResult.get(); + } + } + + return result; + } + + @Override + public Object map(Types.MapType map, Supplier keyResult, Supplier valueResult) { + int numEntries = random.nextInt(20); + + Object[] keysArr = new Object[numEntries]; + Object[] valuesArr = new Object[numEntries]; + GenericArrayData keys = new GenericArrayData(keysArr); + GenericArrayData values = new GenericArrayData(valuesArr); + ArrayBasedMapData result = new ArrayBasedMapData(keys, values); + + Set keySet = Sets.newHashSet(); + for (int i = 0; i < numEntries; i += 1) { + Object key = keyResult.get(); + // ensure no collisions + while (keySet.contains(key)) { + key = keyResult.get(); + } + + keySet.add(key); + + keysArr[i] = key; + // return null 5% of the time when the value is optional + if (map.isValueOptional() && random.nextInt(20) == 1) { + valuesArr[i] = null; + } else { + valuesArr[i] = valueResult.get(); + } + } + + return result; + } + + @Override + public Object primitive(Type.PrimitiveType primitive) { + Object obj = RandomUtil.generatePrimitive(primitive, random); + switch (primitive.typeId()) { + case STRING: + return UTF8String.fromString((String) obj); + case DECIMAL: + return Decimal.apply((BigDecimal) obj); + case UUID: + return UTF8String.fromString(UUID.nameUUIDFromBytes((byte[]) obj).toString()); + default: + return obj; + } + } + } + + private static class DictionaryEncodedDataGenerator extends RandomDataGenerator { + private DictionaryEncodedDataGenerator(Schema schema, long seed, float nullPercentage) { + super(schema, seed, nullPercentage); + } + + @Override + protected Object randomValue(Type.PrimitiveType primitive, Random random) { + return RandomUtil.generateDictionaryEncodablePrimitive(primitive, random); + } + } + + private static class FallbackDataGenerator 
extends RandomDataGenerator { + private final long dictionaryEncodedRows; + private long rowCount = 0; + + private FallbackDataGenerator(Schema schema, long seed, long numDictionaryEncoded) { + super(schema, seed, DEFAULT_NULL_PERCENTAGE); + this.dictionaryEncodedRows = numDictionaryEncoded; + } + + @Override + protected Object randomValue(Type.PrimitiveType primitive, Random rand) { + this.rowCount += 1; + if (rowCount > dictionaryEncodedRows) { + return RandomUtil.generatePrimitive(primitive, rand); + } else { + return RandomUtil.generateDictionaryEncodablePrimitive(primitive, rand); + } + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java new file mode 100644 index 000000000000..d64ca588f202 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java @@ -0,0 +1,843 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.spark.SparkSchemaUtil.convert; +import static org.assertj.core.api.Assertions.assertThat; +import static scala.collection.JavaConverters.mapAsJavaMapConverter; +import static scala.collection.JavaConverters.seqAsJavaListConverter; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.time.temporal.ChronoUnit; +import java.util.Collection; +import java.util.Date; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileContent; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Streams; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.DeleteFileSet; +import org.apache.orc.storage.serde2.io.DateWritable; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import 
org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.unsafe.types.UTF8String; +import scala.collection.Seq; + +public class TestHelpers { + + private TestHelpers() {} + + public static void assertEqualsSafe(Types.StructType struct, List recs, List rows) { + Streams.forEachPair( + recs.stream(), rows.stream(), (rec, row) -> assertEqualsSafe(struct, rec, row)); + } + + public static void assertEqualsSafe(Types.StructType struct, Record rec, Row row) { + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + + Object expectedValue = rec.get(i); + Object actualValue = row.get(i); + + assertEqualsSafe(fieldType, expectedValue, actualValue); + } + } + + public static void assertEqualsBatch( + Types.StructType struct, Iterator expected, ColumnarBatch batch) { + for (int rowId = 0; rowId < batch.numRows(); rowId++) { + List fields = struct.fields(); + InternalRow row = batch.getRow(rowId); + Record rec = expected.next(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + Object expectedValue = rec.get(i); + Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType)); + assertEqualsUnsafe(fieldType, expectedValue, actualValue); + } + } + } + + public static void assertEqualsBatchWithRows( + Types.StructType struct, Iterator expected, ColumnarBatch batch) { + for (int rowId = 0; rowId < batch.numRows(); rowId++) { + List fields = struct.fields(); + InternalRow row = batch.getRow(rowId); + Row expectedRow = expected.next(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + Object expectedValue = expectedRow.get(i); + Object actualValue = row.isNullAt(i) ? 
null : row.get(i, convert(fieldType)); + assertEqualsUnsafe(fieldType, expectedValue, actualValue); + } + } + } + + private static void assertEqualsSafe(Types.ListType list, Collection expected, List actual) { + Type elementType = list.elementType(); + List expectedElements = Lists.newArrayList(expected); + for (int i = 0; i < expectedElements.size(); i += 1) { + Object expectedValue = expectedElements.get(i); + Object actualValue = actual.get(i); + + assertEqualsSafe(elementType, expectedValue, actualValue); + } + } + + private static void assertEqualsSafe(Types.MapType map, Map expected, Map actual) { + Type keyType = map.keyType(); + Type valueType = map.valueType(); + + for (Object expectedKey : expected.keySet()) { + Object matchingKey = null; + for (Object actualKey : actual.keySet()) { + try { + assertEqualsSafe(keyType, expectedKey, actualKey); + matchingKey = actualKey; + } catch (AssertionError e) { + // failed + } + } + + assertThat(matchingKey).as("Should have a matching key").isNotNull(); + assertEqualsSafe(valueType, expected.get(expectedKey), actual.get(matchingKey)); + } + } + + private static final OffsetDateTime EPOCH = Instant.ofEpochMilli(0L).atOffset(ZoneOffset.UTC); + private static final LocalDate EPOCH_DAY = EPOCH.toLocalDate(); + + @SuppressWarnings("unchecked") + private static void assertEqualsSafe(Type type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + switch (type.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + assertThat(actual).as("Primitive value should be equal to expected").isEqualTo(expected); + break; + case DATE: + assertThat(expected).as("Should be an int").isInstanceOf(Integer.class); + assertThat(actual).as("Should be a Date").isInstanceOf(Date.class); + LocalDate date = ChronoUnit.DAYS.addTo(EPOCH_DAY, (Integer) expected); + assertThat(actual.toString()) + .as("ISO-8601 date should be equal") + .isEqualTo(String.valueOf(date)); + break; + case TIMESTAMP: + Types.TimestampType timestampType = (Types.TimestampType) type; + + assertThat(expected).as("Should be a long").isInstanceOf(Long.class); + if (timestampType.shouldAdjustToUTC()) { + assertThat(actual).as("Should be a Timestamp").isInstanceOf(Timestamp.class); + + Timestamp ts = (Timestamp) actual; + // milliseconds from nanos has already been added by getTime + long tsMicros = (ts.getTime() * 1000) + ((ts.getNanos() / 1000) % 1000); + assertThat(tsMicros).as("Timestamp micros should be equal").isEqualTo(expected); + } else { + assertThat(actual).as("Should be a LocalDateTime").isInstanceOf(LocalDateTime.class); + + LocalDateTime ts = (LocalDateTime) actual; + Instant instant = ts.toInstant(ZoneOffset.UTC); + // milliseconds from nanos has already been added by getTime + long tsMicros = (instant.toEpochMilli() * 1000) + ((ts.getNano() / 1000) % 1000); + assertThat(tsMicros).as("Timestamp micros should be equal").isEqualTo(expected); + } + break; + case STRING: + assertThat(actual).as("Should be a String").isInstanceOf(String.class); + assertThat(actual).as("Strings should be equal").isEqualTo(String.valueOf(expected)); + break; + case UUID: + assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class); + assertThat(actual).as("Should be a String").isInstanceOf(String.class); + assertThat(actual.toString()) + .as("UUID string representation should match") + .isEqualTo(String.valueOf(expected)); + break; + case FIXED: + assertThat(expected).as("Should expect a 
Fixed").isInstanceOf(GenericData.Fixed.class); + assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + assertThat(actual) + .as("Bytes should match") + .isEqualTo(((GenericData.Fixed) expected).bytes()); + break; + case BINARY: + assertThat(expected).as("Should expect a ByteBuffer").isInstanceOf(ByteBuffer.class); + assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + assertThat(actual).as("Bytes should match").isEqualTo(((ByteBuffer) expected).array()); + break; + case DECIMAL: + assertThat(expected).as("Should expect a BigDecimal").isInstanceOf(BigDecimal.class); + assertThat(actual).as("Should be a BigDecimal").isInstanceOf(BigDecimal.class); + assertThat(actual).as("BigDecimals should be equal").isEqualTo(expected); + break; + case STRUCT: + assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); + assertThat(actual).as("Should be a Row").isInstanceOf(Row.class); + assertEqualsSafe(type.asNestedType().asStructType(), (Record) expected, (Row) actual); + break; + case LIST: + assertThat(expected).as("Should expect a Collection").isInstanceOf(Collection.class); + assertThat(actual).as("Should be a Seq").isInstanceOf(Seq.class); + List asList = seqAsJavaListConverter((Seq) actual).asJava(); + assertEqualsSafe(type.asNestedType().asListType(), (Collection) expected, asList); + break; + case MAP: + assertThat(expected).as("Should expect a Collection").isInstanceOf(Map.class); + assertThat(actual).as("Should be a Map").isInstanceOf(scala.collection.Map.class); + Map asMap = + mapAsJavaMapConverter((scala.collection.Map) actual).asJava(); + assertEqualsSafe(type.asNestedType().asMapType(), (Map) expected, asMap); + break; + case TIME: + default: + throw new IllegalArgumentException("Not a supported type: " + type); + } + } + + public static void assertEqualsUnsafe(Types.StructType struct, Record rec, InternalRow row) { + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + + Object expectedValue = rec.get(i); + Object actualValue = row.isNullAt(i) ? 
null : row.get(i, convert(fieldType)); + + assertEqualsUnsafe(fieldType, expectedValue, actualValue); + } + } + + private static void assertEqualsUnsafe( + Types.ListType list, Collection expected, ArrayData actual) { + Type elementType = list.elementType(); + List expectedElements = Lists.newArrayList(expected); + for (int i = 0; i < expectedElements.size(); i += 1) { + Object expectedValue = expectedElements.get(i); + Object actualValue = actual.get(i, convert(elementType)); + + assertEqualsUnsafe(elementType, expectedValue, actualValue); + } + } + + private static void assertEqualsUnsafe(Types.MapType map, Map expected, MapData actual) { + Type keyType = map.keyType(); + Type valueType = map.valueType(); + + List> expectedElements = Lists.newArrayList(expected.entrySet()); + ArrayData actualKeys = actual.keyArray(); + ArrayData actualValues = actual.valueArray(); + + for (int i = 0; i < expectedElements.size(); i += 1) { + Map.Entry expectedPair = expectedElements.get(i); + Object actualKey = actualKeys.get(i, convert(keyType)); + Object actualValue = actualValues.get(i, convert(keyType)); + + assertEqualsUnsafe(keyType, expectedPair.getKey(), actualKey); + assertEqualsUnsafe(valueType, expectedPair.getValue(), actualValue); + } + } + + private static void assertEqualsUnsafe(Type type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + switch (type.typeId()) { + case LONG: + assertThat(actual).as("Should be a long").isInstanceOf(Long.class); + if (expected instanceof Integer) { + assertThat(actual).as("Values didn't match").isEqualTo(((Number) expected).longValue()); + } else { + assertThat(actual).as("Primitive value should be equal to expected").isEqualTo(expected); + } + break; + case DOUBLE: + assertThat(actual).as("Should be a double").isInstanceOf(Double.class); + if (expected instanceof Float) { + assertThat(Double.doubleToLongBits((double) actual)) + .as("Values didn't match") + .isEqualTo(Double.doubleToLongBits(((Number) expected).doubleValue())); + } else { + assertThat(actual).as("Primitive value should be equal to expected").isEqualTo(expected); + } + break; + case INTEGER: + case FLOAT: + case BOOLEAN: + case DATE: + case TIMESTAMP: + assertThat(actual).as("Primitive value should be equal to expected").isEqualTo(expected); + break; + case STRING: + assertThat(actual).as("Should be a UTF8String").isInstanceOf(UTF8String.class); + assertThat(actual.toString()).as("Strings should be equal").isEqualTo(expected); + break; + case UUID: + assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class); + assertThat(actual).as("Should be a UTF8String").isInstanceOf(UTF8String.class); + assertThat(actual.toString()) + .as("UUID string representation should match") + .isEqualTo(String.valueOf(expected)); + break; + case FIXED: + assertThat(expected).as("Should expect a Fixed").isInstanceOf(GenericData.Fixed.class); + assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + assertThat(actual) + .as("Bytes should match") + .isEqualTo(((GenericData.Fixed) expected).bytes()); + break; + case BINARY: + assertThat(expected).as("Should expect a ByteBuffer").isInstanceOf(ByteBuffer.class); + assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + assertThat(actual).as("Bytes should match").isEqualTo(((ByteBuffer) expected).array()); + break; + case DECIMAL: + assertThat(expected).as("Should expect a BigDecimal").isInstanceOf(BigDecimal.class); + assertThat(actual).as("Should be a 
Decimal").isInstanceOf(Decimal.class); + assertThat(((Decimal) actual).toJavaBigDecimal()) + .as("BigDecimals should be equal") + .isEqualTo(expected); + break; + case STRUCT: + assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); + assertThat(actual).as("Should be an InternalRow").isInstanceOf(InternalRow.class); + assertEqualsUnsafe( + type.asNestedType().asStructType(), (Record) expected, (InternalRow) actual); + break; + case LIST: + assertThat(expected).as("Should expect a Collection").isInstanceOf(Collection.class); + assertThat(actual).as("Should be an ArrayData").isInstanceOf(ArrayData.class); + assertEqualsUnsafe( + type.asNestedType().asListType(), (Collection) expected, (ArrayData) actual); + break; + case MAP: + assertThat(expected).as("Should expect a Map").isInstanceOf(Map.class); + assertThat(actual).as("Should be an ArrayBasedMapData").isInstanceOf(MapData.class); + assertEqualsUnsafe(type.asNestedType().asMapType(), (Map) expected, (MapData) actual); + break; + case TIME: + default: + throw new IllegalArgumentException("Not a supported type: " + type); + } + } + + /** + * Check that the given InternalRow is equivalent to the Row. + * + * @param prefix context for error messages + * @param type the type of the row + * @param expected the expected value of the row + * @param actual the actual value of the row + */ + public static void assertEquals( + String prefix, Types.StructType type, InternalRow expected, Row actual) { + if (expected == null || actual == null) { + assertThat(actual).as(prefix).isEqualTo(expected); + } else { + List fields = type.fields(); + for (int c = 0; c < fields.size(); ++c) { + String fieldName = fields.get(c).name(); + Type childType = fields.get(c).type(); + switch (childType.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + case STRING: + case DECIMAL: + case DATE: + case TIMESTAMP: + assertThat(getPrimitiveValue(actual, c, childType)) + .as(prefix + "." + fieldName + " - " + childType) + .isEqualTo(getValue(expected, c, childType)); + break; + case UUID: + case FIXED: + case BINARY: + assertEqualBytes( + prefix + "." + fieldName, + (byte[]) getValue(expected, c, childType), + (byte[]) actual.get(c)); + break; + case STRUCT: + { + Types.StructType st = (Types.StructType) childType; + assertEquals( + prefix + "." + fieldName, + st, + expected.getStruct(c, st.fields().size()), + actual.getStruct(c)); + break; + } + case LIST: + assertEqualsLists( + prefix + "." + fieldName, + childType.asListType(), + expected.getArray(c), + toList((Seq) actual.get(c))); + break; + case MAP: + assertEqualsMaps( + prefix + "." 
+ fieldName, + childType.asMapType(), + expected.getMap(c), + toJavaMap((scala.collection.Map) actual.getMap(c))); + break; + default: + throw new IllegalArgumentException("Unhandled type " + childType); + } + } + } + } + + private static void assertEqualsLists( + String prefix, Types.ListType type, ArrayData expected, List actual) { + if (expected == null || actual == null) { + assertThat(actual).as(prefix).isEqualTo(expected); + } else { + assertThat(actual.size()).as(prefix + "length").isEqualTo(expected.numElements()); + Type childType = type.elementType(); + for (int e = 0; e < expected.numElements(); ++e) { + switch (childType.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + case STRING: + case DECIMAL: + case DATE: + case TIMESTAMP: + assertThat(actual.get(e)) + .as(prefix + ".elem " + e + " - " + childType) + .isEqualTo(getValue(expected, e, childType)); + break; + case UUID: + case FIXED: + case BINARY: + assertEqualBytes( + prefix + ".elem " + e, + (byte[]) getValue(expected, e, childType), + (byte[]) actual.get(e)); + break; + case STRUCT: + { + Types.StructType st = (Types.StructType) childType; + assertEquals( + prefix + ".elem " + e, + st, + expected.getStruct(e, st.fields().size()), + (Row) actual.get(e)); + break; + } + case LIST: + assertEqualsLists( + prefix + ".elem " + e, + childType.asListType(), + expected.getArray(e), + toList((Seq) actual.get(e))); + break; + case MAP: + assertEqualsMaps( + prefix + ".elem " + e, + childType.asMapType(), + expected.getMap(e), + toJavaMap((scala.collection.Map) actual.get(e))); + break; + default: + throw new IllegalArgumentException("Unhandled type " + childType); + } + } + } + } + + private static void assertEqualsMaps( + String prefix, Types.MapType type, MapData expected, Map actual) { + if (expected == null || actual == null) { + assertThat(actual).as(prefix).isEqualTo(expected); + } else { + Type keyType = type.keyType(); + Type valueType = type.valueType(); + ArrayData expectedKeyArray = expected.keyArray(); + ArrayData expectedValueArray = expected.valueArray(); + assertThat(actual.size()).as(prefix + " length").isEqualTo(expected.numElements()); + for (int e = 0; e < expected.numElements(); ++e) { + Object expectedKey = getValue(expectedKeyArray, e, keyType); + Object actualValue = actual.get(expectedKey); + if (actualValue == null) { + assertThat(true) + .as(prefix + ".key=" + expectedKey + " has null") + .isEqualTo(expected.valueArray().isNullAt(e)); + } else { + switch (valueType.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + case STRING: + case DECIMAL: + case DATE: + case TIMESTAMP: + assertThat(actual.get(expectedKey)) + .as(prefix + ".key=" + expectedKey + " - " + valueType) + .isEqualTo(getValue(expectedValueArray, e, valueType)); + break; + case UUID: + case FIXED: + case BINARY: + assertEqualBytes( + prefix + ".key=" + expectedKey, + (byte[]) getValue(expectedValueArray, e, valueType), + (byte[]) actual.get(expectedKey)); + break; + case STRUCT: + { + Types.StructType st = (Types.StructType) valueType; + assertEquals( + prefix + ".key=" + expectedKey, + st, + expectedValueArray.getStruct(e, st.fields().size()), + (Row) actual.get(expectedKey)); + break; + } + case LIST: + assertEqualsLists( + prefix + ".key=" + expectedKey, + valueType.asListType(), + expectedValueArray.getArray(e), + toList((Seq) actual.get(expectedKey))); + break; + case MAP: + assertEqualsMaps( + prefix + ".key=" + expectedKey, + valueType.asMapType(), + 
expectedValueArray.getMap(e), + toJavaMap((scala.collection.Map) actual.get(expectedKey))); + break; + default: + throw new IllegalArgumentException("Unhandled type " + valueType); + } + } + } + } + } + + private static Object getValue(SpecializedGetters container, int ord, Type type) { + if (container.isNullAt(ord)) { + return null; + } + switch (type.typeId()) { + case BOOLEAN: + return container.getBoolean(ord); + case INTEGER: + return container.getInt(ord); + case LONG: + return container.getLong(ord); + case FLOAT: + return container.getFloat(ord); + case DOUBLE: + return container.getDouble(ord); + case STRING: + return container.getUTF8String(ord).toString(); + case BINARY: + case FIXED: + case UUID: + return container.getBinary(ord); + case DATE: + return new DateWritable(container.getInt(ord)).get(); + case TIMESTAMP: + return DateTimeUtils.toJavaTimestamp(container.getLong(ord)); + case DECIMAL: + { + Types.DecimalType dt = (Types.DecimalType) type; + return container.getDecimal(ord, dt.precision(), dt.scale()).toJavaBigDecimal(); + } + case STRUCT: + Types.StructType struct = type.asStructType(); + InternalRow internalRow = container.getStruct(ord, struct.fields().size()); + Object[] data = new Object[struct.fields().size()]; + for (int i = 0; i < data.length; i += 1) { + if (internalRow.isNullAt(i)) { + data[i] = null; + } else { + data[i] = getValue(internalRow, i, struct.fields().get(i).type()); + } + } + return new GenericRow(data); + default: + throw new IllegalArgumentException("Unhandled type " + type); + } + } + + private static Object getPrimitiveValue(Row row, int ord, Type type) { + if (row.isNullAt(ord)) { + return null; + } + switch (type.typeId()) { + case BOOLEAN: + return row.getBoolean(ord); + case INTEGER: + return row.getInt(ord); + case LONG: + return row.getLong(ord); + case FLOAT: + return row.getFloat(ord); + case DOUBLE: + return row.getDouble(ord); + case STRING: + return row.getString(ord); + case BINARY: + case FIXED: + case UUID: + return row.get(ord); + case DATE: + return row.getDate(ord); + case TIMESTAMP: + return row.getTimestamp(ord); + case DECIMAL: + return row.getDecimal(ord); + default: + throw new IllegalArgumentException("Unhandled type " + type); + } + } + + private static Map toJavaMap(scala.collection.Map map) { + return map == null ? null : mapAsJavaMapConverter(map).asJava(); + } + + private static List toList(Seq val) { + return val == null ? 
null : seqAsJavaListConverter(val).asJava(); + } + + private static void assertEqualBytes(String context, byte[] expected, byte[] actual) { + assertThat(actual).as(context).isEqualTo(expected); + } + + static void assertEquals(Schema schema, Object expected, Object actual) { + assertEquals("schema", convert(schema), expected, actual); + } + + private static void assertEquals(String context, DataType type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + if (type instanceof StructType) { + assertThat(expected) + .as("Expected should be an InternalRow: " + context) + .isInstanceOf(InternalRow.class); + assertThat(actual) + .as("Actual should be an InternalRow: " + context) + .isInstanceOf(InternalRow.class); + assertEquals(context, (StructType) type, (InternalRow) expected, (InternalRow) actual); + + } else if (type instanceof ArrayType) { + assertThat(expected) + .as("Expected should be an ArrayData: " + context) + .isInstanceOf(ArrayData.class); + assertThat(actual) + .as("Actual should be an ArrayData: " + context) + .isInstanceOf(ArrayData.class); + assertEquals(context, (ArrayType) type, (ArrayData) expected, (ArrayData) actual); + + } else if (type instanceof MapType) { + assertThat(expected) + .as("Expected should be a MapData: " + context) + .isInstanceOf(MapData.class); + assertThat(actual).as("Actual should be a MapData: " + context).isInstanceOf(MapData.class); + assertEquals(context, (MapType) type, (MapData) expected, (MapData) actual); + + } else if (type instanceof BinaryType) { + assertEqualBytes(context, (byte[]) expected, (byte[]) actual); + } else { + assertThat(actual).as("Value should match expected: " + context).isEqualTo(expected); + } + } + + private static void assertEquals( + String context, StructType struct, InternalRow expected, InternalRow actual) { + assertThat(actual.numFields()) + .as("Should have correct number of fields") + .isEqualTo(struct.size()); + for (int i = 0; i < actual.numFields(); i += 1) { + StructField field = struct.fields()[i]; + DataType type = field.dataType(); + assertEquals( + context + "." + field.name(), + type, + expected.isNullAt(i) ? null : expected.get(i, type), + actual.isNullAt(i) ? null : actual.get(i, type)); + } + } + + private static void assertEquals( + String context, ArrayType array, ArrayData expected, ArrayData actual) { + assertThat(actual.numElements()) + .as("Should have the same number of elements") + .isEqualTo(expected.numElements()); + DataType type = array.elementType(); + for (int i = 0; i < actual.numElements(); i += 1) { + assertEquals( + context + ".element", + type, + expected.isNullAt(i) ? null : expected.get(i, type), + actual.isNullAt(i) ? null : actual.get(i, type)); + } + } + + private static void assertEquals(String context, MapType map, MapData expected, MapData actual) { + assertThat(actual.numElements()) + .as("Should have the same number of elements") + .isEqualTo(expected.numElements()); + + DataType keyType = map.keyType(); + ArrayData expectedKeys = expected.keyArray(); + ArrayData expectedValues = expected.valueArray(); + + DataType valueType = map.valueType(); + ArrayData actualKeys = actual.keyArray(); + ArrayData actualValues = actual.valueArray(); + + for (int i = 0; i < actual.numElements(); i += 1) { + assertEquals( + context + ".key", + keyType, + expectedKeys.isNullAt(i) ? null : expectedKeys.get(i, keyType), + actualKeys.isNullAt(i) ? 
null : actualKeys.get(i, keyType)); + assertEquals( + context + ".value", + valueType, + expectedValues.isNullAt(i) ? null : expectedValues.get(i, valueType), + actualValues.isNullAt(i) ? null : actualValues.get(i, valueType)); + } + } + + public static List dataManifests(Table table) { + return table.currentSnapshot().dataManifests(table.io()); + } + + public static List deleteManifests(Table table) { + return table.currentSnapshot().deleteManifests(table.io()); + } + + public static List dataFiles(Table table) { + return dataFiles(table, null); + } + + public static List dataFiles(Table table, String branch) { + TableScan scan = table.newScan(); + if (branch != null) { + scan = scan.useRef(branch); + } + + CloseableIterable tasks = scan.includeColumnStats().planFiles(); + return Lists.newArrayList(CloseableIterable.transform(tasks, FileScanTask::file)); + } + + public static Set deleteFiles(Table table) { + DeleteFileSet deleteFiles = DeleteFileSet.create(); + + for (FileScanTask task : table.newScan().planFiles()) { + deleteFiles.addAll(task.deletes()); + } + + return deleteFiles; + } + + public static Set reachableManifestPaths(Table table) { + return StreamSupport.stream(table.snapshots().spliterator(), false) + .flatMap(s -> s.allManifests(table.io()).stream()) + .map(ManifestFile::path) + .collect(Collectors.toSet()); + } + + public static void asMetadataRecord(GenericData.Record file, FileContent content) { + file.put(0, content.id()); + file.put(3, 0); // specId + } + + public static void asMetadataRecord(GenericData.Record file) { + file.put(0, FileContent.DATA.id()); + file.put(3, 0); // specId + } + + public static Dataset selectNonDerived(Dataset metadataTable) { + StructField[] fields = metadataTable.schema().fields(); + return metadataTable.select( + Stream.of(fields) + .filter(f -> !f.name().equals("readable_metrics")) // derived field + .map(f -> new Column(f.name())) + .toArray(Column[]::new)); + } + + public static Types.StructType nonDerivedSchema(Dataset metadataTable) { + return SparkSchemaUtil.convert(TestHelpers.selectNonDerived(metadataTable).schema()).asStruct(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestOrcWrite.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestOrcWrite.java new file mode 100644 index 000000000000..cbaad6543076 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestOrcWrite.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
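The TestHelpers class above pairs Iceberg generic Records with Spark Rows ("safe") or InternalRows and ColumnarBatches ("unsafe") and compares them field by field. A minimal usage sketch, assuming a SparkSession named spark and a caller-supplied Iceberg table name (neither is part of this patch; the class and method names are illustrative only):

import static org.apache.iceberg.types.Types.NestedField.optional;

import java.util.List;
import org.apache.avro.generic.GenericData.Record;
import org.apache.iceberg.Schema;
import org.apache.iceberg.spark.data.TestHelpers;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class TestHelpersUsageSketch {
  private static final Schema SCHEMA =
      new Schema(
          optional(1, "id", Types.IntegerType.get()),
          optional(2, "data", Types.StringType.get()));

  // Hypothetical round-trip check: generic records previously written to the table
  // are read back through Spark and compared field by field.
  static void verifyRoundTrip(SparkSession spark, String tableName, List<Record> expected) {
    List<Row> actual =
        spark.read().format("iceberg").load(tableName).orderBy("id").collectAsList();

    // assertEqualsSafe compares against the external Row representation
    // (java.sql.Date/Timestamp, String, BigDecimal, Scala Seq/Map).
    TestHelpers.assertEqualsSafe(SCHEMA.asStruct(), expected, actual);
  }
}

assertEqualsUnsafe is the InternalRow counterpart of this call, and it is what assertEqualsBatch uses when validating vectorized reads.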
+ */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestOrcWrite { + @TempDir private Path temp; + + private static final Schema SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + + @Test + public void splitOffsets() throws IOException { + File testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should succeed").isTrue(); + + Iterable rows = RandomData.generateSpark(SCHEMA, 1, 0L); + FileAppender writer = + ORC.write(Files.localOutput(testFile)) + .createWriterFunc(SparkOrcWriter::new) + .schema(SCHEMA) + .build(); + + writer.addAll(rows); + writer.close(); + assertThat(writer.splitOffsets()).as("Split offsets not present").isNotNull(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroReader.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroReader.java new file mode 100644 index 000000000000..3f9b4bb587ba --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroReader.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
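splitOffsets() above only asserts that the ORC writer records stripe offsets. A hedged companion sketch that reads such a file back through the row-level SparkOrcReader (the schema mirrors the test; the class, method, and file handle are illustrative, not part of this patch):

import static org.apache.iceberg.types.Types.NestedField.optional;

import java.io.File;
import java.io.IOException;
import org.apache.iceberg.Files;
import org.apache.iceberg.Schema;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.orc.ORC;
import org.apache.iceberg.spark.data.SparkOrcReader;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.InternalRow;

public class OrcReadBackSketch {
  private static final Schema SCHEMA =
      new Schema(
          optional(1, "id", Types.IntegerType.get()),
          optional(2, "data", Types.StringType.get()));

  // Reads an ORC file written by SparkOrcWriter (as in splitOffsets()) and counts rows.
  static long countRows(File orcFile) throws IOException {
    try (CloseableIterable<InternalRow> reader =
        ORC.read(Files.localInput(orcFile))
            .project(SCHEMA)
            .createReaderFunc(readOrcSchema -> new SparkOrcReader(SCHEMA, readOrcSchema))
            .build()) {
      long count = 0;
      for (InternalRow ignored : reader) {
        count += 1;
      }
      return count;
    }
  }
}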
+ */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Iterator; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetAvroValueReaders; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.schema.MessageType; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestParquetAvroReader { + @TempDir private Path temp; + + private static final Schema COMPLEX_SCHEMA = + new Schema( + required(1, "roots", Types.LongType.get()), + optional(3, "lime", Types.ListType.ofRequired(4, Types.DoubleType.get())), + required( + 5, + "strict", + Types.StructType.of( + required(9, "tangerine", Types.StringType.get()), + optional( + 6, + "hopeful", + Types.StructType.of( + required(7, "steel", Types.FloatType.get()), + required(8, "lantern", Types.DateType.get()))), + optional(10, "vehement", Types.LongType.get()))), + optional( + 11, + "metamorphosis", + Types.MapType.ofRequired( + 12, 13, Types.StringType.get(), Types.TimestampType.withoutZone())), + required( + 14, + "winter", + Types.ListType.ofOptional( + 15, + Types.StructType.of( + optional(16, "beet", Types.DoubleType.get()), + required(17, "stamp", Types.TimeType.get()), + optional(18, "wheeze", Types.StringType.get())))), + optional( + 19, + "renovate", + Types.MapType.ofRequired( + 20, + 21, + Types.StringType.get(), + Types.StructType.of( + optional(22, "jumpy", Types.DoubleType.get()), + required(23, "koala", Types.TimeType.get()), + required(24, "couch rope", Types.IntegerType.get())))), + optional(2, "slide", Types.StringType.get()), + required(25, "foo", Types.DecimalType.of(7, 5))); + + @Disabled + public void testStructSchema() throws IOException { + Schema structSchema = + new Schema( + required(1, "circumvent", Types.LongType.get()), + optional(2, "antarctica", Types.StringType.get()), + optional(3, "fluent", Types.DoubleType.get()), + required( + 4, + "quell", + Types.StructType.of( + required(5, "operator", Types.BooleanType.get()), + optional(6, "fanta", Types.IntegerType.get()), + optional(7, "cable", Types.FloatType.get()))), + required(8, "chimney", Types.TimestampType.withZone()), + required(9, "wool", Types.DateType.get())); + + File testFile = writeTestData(structSchema, 5_000_000, 1059); + // RandomData uses the root record name "test", which must match for records to be equal + MessageType readSchema = ParquetSchemaUtil.convert(structSchema, "test"); + + long sum = 0; + long sumSq = 0; + int warmups = 2; + int trials = 10; + + for (int i = 0; i < warmups + trials; i += 1) { + // clean up as much memory as possible to avoid a large GC during the timed run + System.gc(); + + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)) + .project(structSchema) + .createReaderFunc( + fileSchema -> ParquetAvroValueReaders.buildReader(structSchema, readSchema)) + .build()) { + long start = System.currentTimeMillis(); + long val = 0; + long count = 0; + 
for (Record record : reader) { + // access something to ensure the compiler doesn't optimize this away + val ^= (Long) record.get(0); + count += 1; + } + long end = System.currentTimeMillis(); + long duration = end - start; + + if (i >= warmups) { + sum += duration; + sumSq += duration * duration; + } + } + } + + double mean = ((double) sum) / trials; + double stddev = Math.sqrt((((double) sumSq) / trials) - (mean * mean)); + } + + @Disabled + public void testWithOldReadPath() throws IOException { + File testFile = writeTestData(COMPLEX_SCHEMA, 500_000, 1985); + // RandomData uses the root record name "test", which must match for records to be equal + MessageType readSchema = ParquetSchemaUtil.convert(COMPLEX_SCHEMA, "test"); + + for (int i = 0; i < 5; i += 1) { + // clean up as much memory as possible to avoid a large GC during the timed run + System.gc(); + + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)).project(COMPLEX_SCHEMA).build()) { + long start = System.currentTimeMillis(); + long val = 0; + long count = 0; + for (Record record : reader) { + // access something to ensure the compiler doesn't optimize this away + val ^= (Long) record.get(0); + count += 1; + } + long end = System.currentTimeMillis(); + } + + // clean up as much memory as possible to avoid a large GC during the timed run + System.gc(); + + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)) + .project(COMPLEX_SCHEMA) + .createReaderFunc( + fileSchema -> ParquetAvroValueReaders.buildReader(COMPLEX_SCHEMA, readSchema)) + .build()) { + long start = System.currentTimeMillis(); + long val = 0; + long count = 0; + for (Record record : reader) { + // access something to ensure the compiler doesn't optimize this away + val ^= (Long) record.get(0); + count += 1; + } + long end = System.currentTimeMillis(); + } + } + } + + @Test + public void testCorrectness() throws IOException { + Iterable records = RandomData.generate(COMPLEX_SCHEMA, 50_000, 34139); + + File testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should succeed").isTrue(); + + try (FileAppender writer = + Parquet.write(Files.localOutput(testFile)).schema(COMPLEX_SCHEMA).build()) { + writer.addAll(records); + } + + // RandomData uses the root record name "test", which must match for records to be equal + MessageType readSchema = ParquetSchemaUtil.convert(COMPLEX_SCHEMA, "test"); + + // verify that the new read path is correct + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)) + .project(COMPLEX_SCHEMA) + .createReaderFunc( + fileSchema -> ParquetAvroValueReaders.buildReader(COMPLEX_SCHEMA, readSchema)) + .reuseContainers() + .build()) { + int recordNum = 0; + Iterator iter = records.iterator(); + for (Record actual : reader) { + Record expected = iter.next(); + assertThat(actual).as("Record " + recordNum + " should match expected").isEqualTo(expected); + recordNum += 1; + } + } + } + + private File writeTestData(Schema schema, int numRecords, int seed) throws IOException { + File testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should succeed").isTrue(); + + try (FileAppender writer = + Parquet.write(Files.localOutput(testFile)).schema(schema).build()) { + writer.addAll(RandomData.generate(schema, numRecords, seed)); + } + + return testFile; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroWriter.java 
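The two @Disabled benchmarks above time the read path over several trials after warm-up iterations and reduce the timings with a running sum and sum of squares. The same arithmetic, pulled into a free-standing helper purely for illustration (the class name, method name, and Callable signature are assumptions, not part of the patch):

import java.util.concurrent.Callable;

public class TimingSketch {
  // Runs `work` for warmups + trials iterations, timing only the trial iterations,
  // and returns {mean, stddev} in milliseconds -- the same reduction used by the
  // disabled testStructSchema() benchmark (sum and sum of squares of durations).
  static double[] timeTrials(Callable<?> work, int warmups, int trials) throws Exception {
    long sum = 0;
    long sumSq = 0;
    for (int i = 0; i < warmups + trials; i += 1) {
      System.gc(); // reduce the chance of a GC pause inside the timed region
      long start = System.currentTimeMillis();
      work.call();
      long duration = System.currentTimeMillis() - start;
      if (i >= warmups) {
        sum += duration;
        sumSq += duration * duration;
      }
    }
    double mean = ((double) sum) / trials;
    double stddev = Math.sqrt((((double) sumSq) / trials) - (mean * mean));
    return new double[] {mean, stddev};
  }
}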
b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroWriter.java new file mode 100644 index 000000000000..83f8f7f168b1 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroWriter.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Iterator; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetAvroValueReaders; +import org.apache.iceberg.parquet.ParquetAvroWriter; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.schema.MessageType; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestParquetAvroWriter { + @TempDir private Path temp; + + private static final Schema COMPLEX_SCHEMA = + new Schema( + required(1, "roots", Types.LongType.get()), + optional(3, "lime", Types.ListType.ofRequired(4, Types.DoubleType.get())), + required( + 5, + "strict", + Types.StructType.of( + required(9, "tangerine", Types.StringType.get()), + optional( + 6, + "hopeful", + Types.StructType.of( + required(7, "steel", Types.FloatType.get()), + required(8, "lantern", Types.DateType.get()))), + optional(10, "vehement", Types.LongType.get()))), + optional( + 11, + "metamorphosis", + Types.MapType.ofRequired( + 12, 13, Types.StringType.get(), Types.TimestampType.withoutZone())), + required( + 14, + "winter", + Types.ListType.ofOptional( + 15, + Types.StructType.of( + optional(16, "beet", Types.DoubleType.get()), + required(17, "stamp", Types.TimeType.get()), + optional(18, "wheeze", Types.StringType.get())))), + optional( + 19, + "renovate", + Types.MapType.ofRequired( + 20, + 21, + Types.StringType.get(), + Types.StructType.of( + optional(22, "jumpy", Types.DoubleType.get()), + required(23, "koala", Types.TimeType.get()), + required(24, "couch rope", Types.IntegerType.get())))), + optional(2, "slide", Types.StringType.get())); + + @Test + public void testCorrectness() throws IOException { + Iterable records = RandomData.generate(COMPLEX_SCHEMA, 50_000, 34139); + + File testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should 
succeed").isTrue(); + + try (FileAppender writer = + Parquet.write(Files.localOutput(testFile)) + .schema(COMPLEX_SCHEMA) + .createWriterFunc(ParquetAvroWriter::buildWriter) + .build()) { + writer.addAll(records); + } + + // RandomData uses the root record name "test", which must match for records to be equal + MessageType readSchema = ParquetSchemaUtil.convert(COMPLEX_SCHEMA, "test"); + + // verify that the new read path is correct + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)) + .project(COMPLEX_SCHEMA) + .createReaderFunc( + fileSchema -> ParquetAvroValueReaders.buildReader(COMPLEX_SCHEMA, readSchema)) + .build()) { + int recordNum = 0; + Iterator iter = records.iterator(); + for (Record actual : reader) { + Record expected = iter.next(); + assertThat(actual).as("Record " + recordNum + " should match expected").isEqualTo(expected); + recordNum += 1; + } + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroEnums.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroEnums.java new file mode 100644 index 000000000000..0dc8b48b2317 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroEnums.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericData.Record; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestSparkAvroEnums { + + @TempDir private Path temp; + + @Test + public void writeAndValidateEnums() throws IOException { + org.apache.avro.Schema avroSchema = + SchemaBuilder.record("root") + .fields() + .name("enumCol") + .type() + .nullable() + .enumeration("testEnum") + .symbols("SYMB1", "SYMB2") + .enumDefault("SYMB2") + .endRecord(); + + org.apache.avro.Schema enumSchema = avroSchema.getField("enumCol").schema().getTypes().get(0); + Record enumRecord1 = new GenericData.Record(avroSchema); + enumRecord1.put("enumCol", new GenericData.EnumSymbol(enumSchema, "SYMB1")); + Record enumRecord2 = new GenericData.Record(avroSchema); + enumRecord2.put("enumCol", new GenericData.EnumSymbol(enumSchema, "SYMB2")); + Record enumRecord3 = new GenericData.Record(avroSchema); // null enum + List expected = ImmutableList.of(enumRecord1, enumRecord2, enumRecord3); + + File testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should succeed").isTrue(); + + try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(enumRecord1); + writer.append(enumRecord2); + writer.append(enumRecord3); + } + + Schema schema = new Schema(AvroSchemaUtil.convert(avroSchema).asStructType().fields()); + List rows; + try (AvroIterable reader = + Avro.read(Files.localInput(testFile)) + .createResolvingReader(SparkPlannedAvroReader::create) + .project(schema) + .build()) { + rows = Lists.newArrayList(reader); + } + + // Iceberg will return enums as strings, so we compare string values for the enum field + for (int i = 0; i < expected.size(); i += 1) { + String expectedEnumString = + expected.get(i).get("enumCol") == null ? null : expected.get(i).get("enumCol").toString(); + String sparkString = + rows.get(i).getUTF8String(0) == null ? null : rows.get(i).getUTF8String(0).toString(); + assertThat(sparkString).isEqualTo(expectedEnumString); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroReader.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroReader.java new file mode 100644 index 000000000000..7f9bcbacf298 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroReader.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
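writeAndValidateEnums() relies on Iceberg surfacing Avro enums as strings, since Iceberg's type system has no enum type. A small sketch that checks the schema conversion directly (the Avro schema is the same shape as in the test; the class and method names are illustrative only):

import static org.assertj.core.api.Assertions.assertThat;

import org.apache.avro.SchemaBuilder;
import org.apache.iceberg.avro.AvroSchemaUtil;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;

public class EnumConversionSketch {
  static void checkEnumMapsToString() {
    org.apache.avro.Schema avroSchema =
        SchemaBuilder.record("root")
            .fields()
            .name("enumCol")
            .type()
            .nullable()
            .enumeration("testEnum")
            .symbols("SYMB1", "SYMB2")
            .enumDefault("SYMB2")
            .endRecord();

    // Iceberg has no enum type, so the converted field should be a string.
    Type converted = AvroSchemaUtil.convert(avroSchema).asStructType().field("enumCol").type();
    assertThat(converted).isEqualTo(Types.StringType.get());
  }
}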
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.catalyst.InternalRow; + +public class TestSparkAvroReader extends AvroDataTest { + @Override + protected void writeAndValidate(Schema schema) throws IOException { + List expected = RandomData.generateList(schema, 100, 0L); + + File testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should succeed").isTrue(); + + try (FileAppender writer = + Avro.write(Files.localOutput(testFile)).schema(schema).named("test").build()) { + for (Record rec : expected) { + writer.add(rec); + } + } + + List rows; + try (AvroIterable reader = + Avro.read(Files.localInput(testFile)) + .createResolvingReader(SparkPlannedAvroReader::create) + .project(schema) + .build()) { + rows = Lists.newArrayList(reader); + } + + for (int i = 0; i < expected.size(); i += 1) { + assertEqualsUnsafe(schema.asStruct(), expected.get(i), rows.get(i)); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkDateTimes.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkDateTimes.java new file mode 100644 index 000000000000..6a06f9d5836d --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkDateTimes.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
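TestSparkAvroReader validates against Spark's internal representation via assertEqualsUnsafe, where strings surface as UTF8String and rows as InternalRow. A hypothetical, hand-built pairing that makes the "unsafe" representation concrete (the field values, record name, and class name are made up for illustration):

import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe;
import static org.apache.iceberg.types.Types.NestedField.optional;

import org.apache.avro.generic.GenericData;
import org.apache.iceberg.Schema;
import org.apache.iceberg.avro.AvroSchemaUtil;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.unsafe.types.UTF8String;

public class UnsafeRepresentationSketch {
  static void check() {
    Schema schema =
        new Schema(
            optional(1, "id", Types.IntegerType.get()),
            optional(2, "data", Types.StringType.get()));

    // Avro-generic "expected" record ...
    GenericData.Record rec = new GenericData.Record(AvroSchemaUtil.convert(schema, "test"));
    rec.put("id", 34);
    rec.put("data", "iceberg");

    // ... and the corresponding Spark internal row: strings become UTF8String.
    GenericInternalRow row = new GenericInternalRow(2);
    row.update(0, 34);
    row.update(1, UTF8String.fromString("iceberg"));

    assertEqualsUnsafe(schema.asStruct(), rec, row);
  }
}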
+ */ +package org.apache.iceberg.spark.data; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.ZoneId; +import java.util.TimeZone; +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.catalyst.util.TimestampFormatter; +import org.junit.jupiter.api.Test; + +public class TestSparkDateTimes { + @Test + public void testSparkDate() { + // checkSparkDate("1582-10-14"); // -141428 + checkSparkDate("1582-10-15"); // first day of the gregorian calendar + checkSparkDate("1601-08-12"); + checkSparkDate("1801-07-04"); + checkSparkDate("1901-08-12"); + checkSparkDate("1969-12-31"); + checkSparkDate("1970-01-01"); + checkSparkDate("2017-12-25"); + checkSparkDate("2043-08-11"); + checkSparkDate("2111-05-03"); + checkSparkDate("2224-02-29"); + checkSparkDate("3224-10-05"); + } + + public void checkSparkDate(String dateString) { + Literal date = Literal.of(dateString).to(Types.DateType.get()); + String sparkDate = DateTimeUtils.toJavaDate(date.value()).toString(); + assertThat(sparkDate) + .as("Should be the same date (" + date.value() + ")") + .isEqualTo(dateString); + } + + @Test + public void testSparkTimestamp() { + TimeZone currentTz = TimeZone.getDefault(); + try { + TimeZone.setDefault(TimeZone.getTimeZone("UTC")); + checkSparkTimestamp("1582-10-15T15:51:08.440219+00:00", "1582-10-15 15:51:08.440219"); + checkSparkTimestamp("1970-01-01T00:00:00.000000+00:00", "1970-01-01 00:00:00"); + checkSparkTimestamp("2043-08-11T12:30:01.000001+00:00", "2043-08-11 12:30:01.000001"); + } finally { + TimeZone.setDefault(currentTz); + } + } + + public void checkSparkTimestamp(String timestampString, String sparkRepr) { + Literal ts = Literal.of(timestampString).to(Types.TimestampType.withZone()); + ZoneId zoneId = DateTimeUtils.getZoneId("UTC"); + TimestampFormatter formatter = TimestampFormatter.getFractionFormatter(zoneId); + String sparkTimestamp = formatter.format(ts.value()); + assertThat(sparkTimestamp) + .as("Should be the same timestamp (" + ts.value() + ")") + .isEqualTo(sparkRepr); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReadMetadataColumns.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReadMetadataColumns.java new file mode 100644 index 000000000000..9d725250d3d2 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReadMetadataColumns.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
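checkSparkDate() converts an Iceberg date literal, stored as days from the Unix epoch, through Spark's DateTimeUtils. A short sketch that makes the epoch-day encoding explicit using java.time (the chosen date and class name are arbitrary, not part of the patch):

import static org.assertj.core.api.Assertions.assertThat;

import java.time.LocalDate;
import org.apache.iceberg.expressions.Literal;
import org.apache.iceberg.types.Types;

public class DateLiteralSketch {
  // Iceberg date literals carry days from the Unix epoch; DateTimeUtils.toJavaDate()
  // in checkSparkDate() consumes exactly that integer.
  static void check() {
    Literal<Integer> date = Literal.of("2017-12-25").to(Types.DateType.get());
    assertThat(date.value().longValue()).isEqualTo(LocalDate.parse("2017-12-25").toEpochDay());
  }
}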
+ */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Schema; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.orc.OrcConf; +import org.apache.orc.OrcFile; +import org.apache.orc.Reader; +import org.apache.orc.StripeInformation; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSparkOrcReadMetadataColumns { + private static final Schema DATA_SCHEMA = + new Schema( + required(100, "id", Types.LongType.get()), required(101, "data", Types.StringType.get())); + + private static final Schema PROJECTION_SCHEMA = + new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get()), + MetadataColumns.ROW_POSITION, + MetadataColumns.IS_DELETED); + + private static final int NUM_ROWS = 1000; + private static final List DATA_ROWS; + private static final List EXPECTED_ROWS; + + static { + DATA_ROWS = Lists.newArrayListWithCapacity(NUM_ROWS); + for (long i = 0; i < NUM_ROWS; i++) { + InternalRow row = new GenericInternalRow(DATA_SCHEMA.columns().size()); + row.update(0, i); + row.update(1, UTF8String.fromString("str" + i)); + DATA_ROWS.add(row); + } + + EXPECTED_ROWS = Lists.newArrayListWithCapacity(NUM_ROWS); + for (long i = 0; i < NUM_ROWS; i++) { + InternalRow row = new GenericInternalRow(PROJECTION_SCHEMA.columns().size()); + row.update(0, i); + row.update(1, UTF8String.fromString("str" + i)); + row.update(2, i); + row.update(3, false); + EXPECTED_ROWS.add(row); + } + } + + @Parameters(name = "vectorized = {0}") + public static Collection parameters() { + return Arrays.asList(false, true); + } + + @TempDir private java.nio.file.Path temp; + + @Parameter private boolean vectorized; + private File testFile; + + @BeforeEach + public void writeFile() throws IOException { + testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should 
succeed").isTrue(); + + try (FileAppender writer = + ORC.write(Files.localOutput(testFile)) + .createWriterFunc(SparkOrcWriter::new) + .schema(DATA_SCHEMA) + // write in such a way that the file contains 10 stripes each with 100 rows + .set("iceberg.orc.vectorbatch.size", "100") + .set(OrcConf.ROWS_BETWEEN_CHECKS.getAttribute(), "100") + .set(OrcConf.STRIPE_SIZE.getAttribute(), "1") + .build()) { + writer.addAll(DATA_ROWS); + } + } + + @TestTemplate + public void testReadRowNumbers() throws IOException { + readAndValidate(null, null, null, EXPECTED_ROWS); + } + + @TestTemplate + public void testReadRowNumbersWithFilter() throws IOException { + readAndValidate( + Expressions.greaterThanOrEqual("id", 500), null, null, EXPECTED_ROWS.subList(500, 1000)); + } + + @TestTemplate + public void testReadRowNumbersWithSplits() throws IOException { + Reader reader; + try { + OrcFile.ReaderOptions readerOptions = + OrcFile.readerOptions(new Configuration()).useUTCTimestamp(true); + reader = OrcFile.createReader(new Path(testFile.toString()), readerOptions); + } catch (IOException ioe) { + throw new RuntimeIOException(ioe, "Failed to open file: %s", testFile); + } + List splitOffsets = + reader.getStripes().stream().map(StripeInformation::getOffset).collect(Collectors.toList()); + List splitLengths = + reader.getStripes().stream().map(StripeInformation::getLength).collect(Collectors.toList()); + + for (int i = 0; i < 10; i++) { + readAndValidate( + null, + splitOffsets.get(i), + splitLengths.get(i), + EXPECTED_ROWS.subList(i * 100, (i + 1) * 100)); + } + } + + private void readAndValidate( + Expression filter, Long splitStart, Long splitLength, List expected) + throws IOException { + Schema projectionWithoutMetadataFields = + TypeUtil.selectNot(PROJECTION_SCHEMA, MetadataColumns.metadataFieldIds()); + CloseableIterable reader = null; + try { + ORC.ReadBuilder builder = + ORC.read(Files.localInput(testFile)).project(projectionWithoutMetadataFields); + + if (vectorized) { + builder = + builder.createBatchedReaderFunc( + readOrcSchema -> + VectorizedSparkOrcReaders.buildReader( + PROJECTION_SCHEMA, readOrcSchema, ImmutableMap.of())); + } else { + builder = + builder.createReaderFunc( + readOrcSchema -> new SparkOrcReader(PROJECTION_SCHEMA, readOrcSchema)); + } + + if (filter != null) { + builder = builder.filter(filter); + } + + if (splitStart != null && splitLength != null) { + builder = builder.split(splitStart, splitLength); + } + + if (vectorized) { + reader = batchesToRows(builder.build()); + } else { + reader = builder.build(); + } + + final Iterator actualRows = reader.iterator(); + final Iterator expectedRows = expected.iterator(); + while (expectedRows.hasNext()) { + assertThat(actualRows).as("Should have expected number of rows").hasNext(); + TestHelpers.assertEquals(PROJECTION_SCHEMA, expectedRows.next(), actualRows.next()); + } + assertThat(actualRows).as("Should not have extra rows").isExhausted(); + } finally { + if (reader != null) { + reader.close(); + } + } + } + + private CloseableIterable batchesToRows(CloseableIterable batches) { + return CloseableIterable.combine( + Iterables.concat(Iterables.transform(batches, b -> (Iterable) b::rowIterator)), + batches); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java new file mode 100644 index 000000000000..5338eaf0855e --- /dev/null +++ 
b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.spark.data.TestHelpers.assertEquals; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterators; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.junit.jupiter.api.Test; + +public class TestSparkOrcReader extends AvroDataTest { + @Override + protected void writeAndValidate(Schema schema) throws IOException { + final Iterable expected = RandomData.generateSpark(schema, 100, 0L); + + writeAndValidateRecords(schema, expected); + } + + @Test + public void writeAndValidateRepeatingRecords() throws IOException { + Schema structSchema = + new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get())); + List expectedRepeating = + Collections.nCopies(100, RandomData.generateSpark(structSchema, 1, 0L).iterator().next()); + + writeAndValidateRecords(structSchema, expectedRepeating); + } + + private void writeAndValidateRecords(Schema schema, Iterable expected) + throws IOException { + File testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should succeed").isTrue(); + + try (FileAppender writer = + ORC.write(Files.localOutput(testFile)) + .createWriterFunc(SparkOrcWriter::new) + .schema(schema) + .build()) { + writer.addAll(expected); + } + + try (CloseableIterable reader = + ORC.read(Files.localInput(testFile)) + .project(schema) + .createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema)) + .build()) { + final Iterator actualRows = reader.iterator(); + final Iterator expectedRows = expected.iterator(); + while (expectedRows.hasNext()) { + assertThat(actualRows).as("Should have expected number of rows").hasNext(); + assertEquals(schema, expectedRows.next(), actualRows.next()); + } + assertThat(actualRows).as("Should not have extra rows").isExhausted(); + } + + try 
(CloseableIterable reader = + ORC.read(Files.localInput(testFile)) + .project(schema) + .createBatchedReaderFunc( + readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(schema, readOrcSchema, ImmutableMap.of())) + .build()) { + final Iterator actualRows = batchesToRows(reader.iterator()); + final Iterator expectedRows = expected.iterator(); + while (expectedRows.hasNext()) { + assertThat(actualRows).as("Should have expected number of rows").hasNext(); + assertEquals(schema, expectedRows.next(), actualRows.next()); + } + assertThat(actualRows).as("Should not have extra rows").isExhausted(); + } + } + + private Iterator batchesToRows(Iterator batches) { + return Iterators.concat(Iterators.transform(batches, ColumnarBatch::rowIterator)); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReadMetadataColumns.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReadMetadataColumns.java new file mode 100644 index 000000000000..044ea3d93c0b --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReadMetadataColumns.java @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
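The metadata-column tests in this change drive the ORC and Parquet readers directly with a projection that includes MetadataColumns.ROW_POSITION and IS_DELETED. For context only, the same row-position information is what Iceberg exposes to query engines as reserved metadata columns; the short sketch below is illustrative and not part of this diff. It assumes a SparkSession already configured with an Iceberg catalog named `local` and an existing table `local.db.tbl` with `id` and `data` columns (all of these names are placeholders).

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    public class MetadataColumnsSketch {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[2]").getOrCreate();
        // _file and _pos are Iceberg reserved metadata columns; _pos carries the same row
        // position that the ROW_POSITION projection in these tests validates at the file level.
        // Catalog and table names here are placeholders.
        Dataset<Row> rows = spark.sql("SELECT _file, _pos, id, data FROM local.db.tbl");
        rows.show(10, false);
        spark.stop();
      }
    }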
+ */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.DeleteFilter; +import org.apache.iceberg.deletes.PositionDeleteIndex; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; +import org.apache.iceberg.types.Types; +import org.apache.parquet.ParquetReadOptions; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.ParquetFileWriter; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.util.HadoopInputFile; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSparkParquetReadMetadataColumns { + private static final Schema DATA_SCHEMA = + new Schema( + required(100, "id", Types.LongType.get()), required(101, "data", Types.StringType.get())); + + private static final Schema PROJECTION_SCHEMA = + new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get()), + MetadataColumns.ROW_POSITION, + MetadataColumns.IS_DELETED); + + private static final int NUM_ROWS = 1000; + private static final List DATA_ROWS; + private static final List EXPECTED_ROWS; + private static final int NUM_ROW_GROUPS = 10; + private static final int ROWS_PER_SPLIT = NUM_ROWS / NUM_ROW_GROUPS; + private static final int RECORDS_PER_BATCH = ROWS_PER_SPLIT / 10; + + static { + DATA_ROWS = Lists.newArrayListWithCapacity(NUM_ROWS); + for (long i = 0; i < NUM_ROWS; i += 1) { + InternalRow row = new GenericInternalRow(DATA_SCHEMA.columns().size()); + if (i >= NUM_ROWS / 2) { + row.update(0, 2 * i); + } else { + row.update(0, i); + } + row.update(1, UTF8String.fromString("str" + i)); + DATA_ROWS.add(row); + } + + EXPECTED_ROWS = Lists.newArrayListWithCapacity(NUM_ROWS); + for (long i = 
0; i < NUM_ROWS; i += 1) { + InternalRow row = new GenericInternalRow(PROJECTION_SCHEMA.columns().size()); + if (i >= NUM_ROWS / 2) { + row.update(0, 2 * i); + } else { + row.update(0, i); + } + row.update(1, UTF8String.fromString("str" + i)); + row.update(2, i); + row.update(3, false); + EXPECTED_ROWS.add(row); + } + } + + @Parameters(name = "vectorized = {0}") + public static Object[][] parameters() { + return new Object[][] {new Object[] {false}, new Object[] {true}}; + } + + @TempDir protected java.nio.file.Path temp; + + @Parameter private boolean vectorized; + private File testFile; + + @BeforeEach + public void writeFile() throws IOException { + List fileSplits = Lists.newArrayList(); + StructType struct = SparkSchemaUtil.convert(DATA_SCHEMA); + Configuration conf = new Configuration(); + + testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should succeed").isTrue(); + ParquetFileWriter parquetFileWriter = + new ParquetFileWriter( + conf, + ParquetSchemaUtil.convert(DATA_SCHEMA, "testSchema"), + new Path(testFile.getAbsolutePath())); + + parquetFileWriter.start(); + for (int i = 0; i < NUM_ROW_GROUPS; i += 1) { + File split = File.createTempFile("junit", null, temp.toFile()); + assertThat(split.delete()).as("Delete should succeed").isTrue(); + fileSplits.add(new Path(split.getAbsolutePath())); + try (FileAppender writer = + Parquet.write(Files.localOutput(split)) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(struct, msgType)) + .schema(DATA_SCHEMA) + .overwrite() + .build()) { + writer.addAll(DATA_ROWS.subList(i * ROWS_PER_SPLIT, (i + 1) * ROWS_PER_SPLIT)); + } + parquetFileWriter.appendFile( + HadoopInputFile.fromPath(new Path(split.getAbsolutePath()), conf)); + } + parquetFileWriter.end( + ParquetFileWriter.mergeMetadataFiles(fileSplits, conf) + .getFileMetaData() + .getKeyValueMetaData()); + } + + @TestTemplate + public void testReadRowNumbers() throws IOException { + readAndValidate(null, null, null, EXPECTED_ROWS); + } + + @TestTemplate + public void testReadRowNumbersWithDelete() throws IOException { + assumeThat(vectorized).isTrue(); + + List expectedRowsAfterDelete = Lists.newArrayList(); + EXPECTED_ROWS.forEach(row -> expectedRowsAfterDelete.add(row.copy())); + // remove row at position 98, 99, 100, 101, 102, this crosses two row groups [0, 100) and [100, + // 200) + for (int i = 98; i <= 102; i++) { + expectedRowsAfterDelete.get(i).update(3, true); + } + + Parquet.ReadBuilder builder = + Parquet.read(Files.localInput(testFile)).project(PROJECTION_SCHEMA); + + DeleteFilter deleteFilter = mock(DeleteFilter.class); + when(deleteFilter.hasPosDeletes()).thenReturn(true); + PositionDeleteIndex deletedRowPos = new CustomizedPositionDeleteIndex(); + deletedRowPos.delete(98, 103); + when(deleteFilter.deletedRowPositions()).thenReturn(deletedRowPos); + + builder.createBatchedReaderFunc( + fileSchema -> + VectorizedSparkParquetReaders.buildReader( + PROJECTION_SCHEMA, fileSchema, Maps.newHashMap(), deleteFilter)); + builder.recordsPerBatch(RECORDS_PER_BATCH); + + validate(expectedRowsAfterDelete, builder); + } + + private class CustomizedPositionDeleteIndex implements PositionDeleteIndex { + private final Set deleteIndex; + + private CustomizedPositionDeleteIndex() { + deleteIndex = Sets.newHashSet(); + } + + @Override + public void delete(long position) { + deleteIndex.add(position); + } + + @Override + public void delete(long posStart, long posEnd) { + for (long l = posStart; l < posEnd; l++) { + delete(l); + } 
+ } + + @Override + public boolean isDeleted(long position) { + return deleteIndex.contains(position); + } + + @Override + public boolean isEmpty() { + return deleteIndex.isEmpty(); + } + } + + @TestTemplate + public void testReadRowNumbersWithFilter() throws IOException { + // current iceberg supports row group filter. + for (int i = 1; i < 5; i += 1) { + readAndValidate( + Expressions.and( + Expressions.lessThan("id", NUM_ROWS / 2), + Expressions.greaterThanOrEqual("id", i * ROWS_PER_SPLIT)), + null, + null, + EXPECTED_ROWS.subList(i * ROWS_PER_SPLIT, NUM_ROWS / 2)); + } + } + + @TestTemplate + public void testReadRowNumbersWithSplits() throws IOException { + ParquetFileReader fileReader = + new ParquetFileReader( + HadoopInputFile.fromPath(new Path(testFile.getAbsolutePath()), new Configuration()), + ParquetReadOptions.builder().build()); + List rowGroups = fileReader.getRowGroups(); + for (int i = 0; i < NUM_ROW_GROUPS; i += 1) { + readAndValidate( + null, + rowGroups.get(i).getColumns().get(0).getStartingPos(), + rowGroups.get(i).getCompressedSize(), + EXPECTED_ROWS.subList(i * ROWS_PER_SPLIT, (i + 1) * ROWS_PER_SPLIT)); + } + } + + private void readAndValidate( + Expression filter, Long splitStart, Long splitLength, List expected) + throws IOException { + Parquet.ReadBuilder builder = + Parquet.read(Files.localInput(testFile)).project(PROJECTION_SCHEMA); + + if (vectorized) { + builder.createBatchedReaderFunc( + fileSchema -> + VectorizedSparkParquetReaders.buildReader( + PROJECTION_SCHEMA, fileSchema, Maps.newHashMap(), null)); + builder.recordsPerBatch(RECORDS_PER_BATCH); + } else { + builder = + builder.createReaderFunc( + msgType -> SparkParquetReaders.buildReader(PROJECTION_SCHEMA, msgType)); + } + + if (filter != null) { + builder = builder.filter(filter); + } + + if (splitStart != null && splitLength != null) { + builder = builder.split(splitStart, splitLength); + } + + validate(expected, builder); + } + + private void validate(List expected, Parquet.ReadBuilder builder) + throws IOException { + try (CloseableIterable reader = + vectorized ? batchesToRows(builder.build()) : builder.build()) { + final Iterator actualRows = reader.iterator(); + + for (InternalRow internalRow : expected) { + assertThat(actualRows).as("Should have expected number of rows").hasNext(); + TestHelpers.assertEquals(PROJECTION_SCHEMA, internalRow, actualRows.next()); + } + + assertThat(actualRows).as("Should not have extra rows").isExhausted(); + } + } + + private CloseableIterable batchesToRows(CloseableIterable batches) { + return CloseableIterable.combine( + Iterables.concat(Iterables.transform(batches, b -> (Iterable) b::rowIterator)), + batches); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java new file mode 100644 index 000000000000..ab0d45c3b7ca --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.avro.generic.GenericData; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.IcebergGenerics; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetUtil; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.hadoop.ParquetWriter; +import org.apache.parquet.hadoop.api.WriteSupport; +import org.apache.parquet.hadoop.util.HadoopOutputFile; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.Test; + +public class TestSparkParquetReader extends AvroDataTest { + @Override + protected void writeAndValidate(Schema schema) throws IOException { + assumeThat( + TypeUtil.find( + schema, + type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get())) + .as("Parquet Avro cannot write non-string map keys") + .isNull(); + + List expected = RandomData.generateList(schema, 100, 0L); + + File testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should succeed").isTrue(); + + try (FileAppender writer = + Parquet.write(Files.localOutput(testFile)).schema(schema).named("test").build()) { + writer.addAll(expected); + } + + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)) + .project(schema) + .createReaderFunc(type -> SparkParquetReaders.buildReader(schema, type)) + .build()) { + Iterator rows = reader.iterator(); + for (GenericData.Record record : expected) { + assertThat(rows).as("Should have expected number of rows").hasNext(); + assertEqualsUnsafe(schema.asStruct(), record, rows.next()); + } + assertThat(rows).as("Should not have extra rows").isExhausted(); + } + } + + protected List rowsFromFile(InputFile inputFile, Schema schema) 
throws IOException { + try (CloseableIterable reader = + Parquet.read(inputFile) + .project(schema) + .createReaderFunc(type -> SparkParquetReaders.buildReader(schema, type)) + .build()) { + return Lists.newArrayList(reader); + } + } + + protected Table tableFromInputFile(InputFile inputFile, Schema schema) throws IOException { + HadoopTables tables = new HadoopTables(); + Table table = + tables.create( + schema, + PartitionSpec.unpartitioned(), + ImmutableMap.of(), + java.nio.file.Files.createTempDirectory(temp, null).toFile().getCanonicalPath()); + + table + .newAppend() + .appendFile( + DataFiles.builder(PartitionSpec.unpartitioned()) + .withFormat(FileFormat.PARQUET) + .withInputFile(inputFile) + .withMetrics(ParquetUtil.fileMetrics(inputFile, MetricsConfig.getDefault())) + .withFileSizeInBytes(inputFile.getLength()) + .build()) + .commit(); + + return table; + } + + @Test + public void testInt96TimestampProducedBySparkIsReadCorrectly() throws IOException { + String outputFilePath = String.format("%s/%s", temp.toAbsolutePath(), "parquet_int96.parquet"); + HadoopOutputFile outputFile = + HadoopOutputFile.fromPath( + new org.apache.hadoop.fs.Path(outputFilePath), new Configuration()); + Schema schema = new Schema(required(1, "ts", Types.TimestampType.withZone())); + StructType sparkSchema = + new StructType( + new StructField[] { + new StructField("ts", DataTypes.TimestampType, true, Metadata.empty()) + }); + List rows = Lists.newArrayList(RandomData.generateSpark(schema, 10, 0L)); + + try (ParquetWriter writer = + new NativeSparkWriterBuilder(outputFile) + .set("org.apache.spark.sql.parquet.row.attributes", sparkSchema.json()) + .set("spark.sql.parquet.writeLegacyFormat", "false") + .set("spark.sql.parquet.outputTimestampType", "INT96") + .set("spark.sql.parquet.fieldId.write.enabled", "true") + .build()) { + for (InternalRow row : rows) { + writer.write(row); + } + } + + InputFile parquetInputFile = Files.localInput(outputFilePath); + List readRows = rowsFromFile(parquetInputFile, schema); + + assertThat(readRows).hasSameSizeAs(rows); + assertThat(readRows).isEqualTo(rows); + + // Now we try to import that file as an Iceberg table to make sure Iceberg can read + // Int96 end to end. + Table int96Table = tableFromInputFile(parquetInputFile, schema); + List tableRecords = Lists.newArrayList(IcebergGenerics.read(int96Table).build()); + + assertThat(tableRecords).hasSameSizeAs(rows); + + for (int i = 0; i < tableRecords.size(); i++) { + GenericsHelpers.assertEqualsUnsafe(schema.asStruct(), tableRecords.get(i), rows.get(i)); + } + } + + /** + * Native Spark ParquetWriter.Builder implementation so that we can write timestamps using Spark's + * native ParquetWriteSupport. 
+ */ + private static class NativeSparkWriterBuilder + extends ParquetWriter.Builder { + private final Map config = Maps.newHashMap(); + + NativeSparkWriterBuilder(org.apache.parquet.io.OutputFile path) { + super(path); + } + + public NativeSparkWriterBuilder set(String property, String value) { + this.config.put(property, value); + return self(); + } + + @Override + protected NativeSparkWriterBuilder self() { + return this; + } + + @Override + protected WriteSupport getWriteSupport(Configuration configuration) { + for (Map.Entry entry : config.entrySet()) { + configuration.set(entry.getKey(), entry.getValue()); + } + + return new org.apache.spark.sql.execution.datasources.parquet.ParquetWriteSupport(); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java new file mode 100644 index 000000000000..73800d3cf3e0 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
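The INT96 test above goes through Spark's internal ParquetWriteSupport to produce a legacy INT96 timestamp file, then verifies both the direct read path and a HadoopTables import. Files with the same characteristic can also be produced through the public DataFrame API by setting spark.sql.parquet.outputTimestampType, which is how such files usually arise in practice. The sketch below is illustrative only; it assumes a local SparkSession and uses a placeholder output path.

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    public class Int96WriteSketch {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[2]").getOrCreate();
        // INT96 is the legacy Parquet physical type for timestamps; files written this way are
        // what the Iceberg read and import paths exercised in the test above must handle.
        spark.conf().set("spark.sql.parquet.outputTimestampType", "INT96");
        Dataset<Row> df = spark.sql("SELECT current_timestamp() AS ts");
        df.write().mode("overwrite").parquet("/tmp/parquet_int96"); // placeholder path
        spark.stop();
      }
    }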
+ */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.TableProperties.PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX; +import static org.apache.iceberg.TableProperties.PARQUET_BLOOM_FILTER_COLUMN_FPP_PREFIX; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.Field; +import java.nio.file.Path; +import java.util.Iterator; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.ParquetProperties; +import org.apache.parquet.schema.MessageType; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestSparkParquetWriter { + @TempDir private Path temp; + + public static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.required(2, "id_long", Types.LongType.get())); + + private static final Schema COMPLEX_SCHEMA = + new Schema( + required(1, "roots", Types.LongType.get()), + optional(3, "lime", Types.ListType.ofRequired(4, Types.DoubleType.get())), + required( + 5, + "strict", + Types.StructType.of( + required(9, "tangerine", Types.StringType.get()), + optional( + 6, + "hopeful", + Types.StructType.of( + required(7, "steel", Types.FloatType.get()), + required(8, "lantern", Types.DateType.get()))), + optional(10, "vehement", Types.LongType.get()))), + optional( + 11, + "metamorphosis", + Types.MapType.ofRequired( + 12, 13, Types.StringType.get(), Types.TimestampType.withZone())), + required( + 14, + "winter", + Types.ListType.ofOptional( + 15, + Types.StructType.of( + optional(16, "beet", Types.DoubleType.get()), + required(17, "stamp", Types.FloatType.get()), + optional(18, "wheeze", Types.StringType.get())))), + optional( + 19, + "renovate", + Types.MapType.ofRequired( + 20, + 21, + Types.StringType.get(), + Types.StructType.of( + optional(22, "jumpy", Types.DoubleType.get()), + required(23, "koala", Types.UUIDType.get()), + required(24, "couch rope", Types.IntegerType.get())))), + optional(2, "slide", Types.StringType.get())); + + @Test + public void testCorrectness() throws IOException { + int numRows = 50_000; + Iterable records = RandomData.generateSpark(COMPLEX_SCHEMA, numRows, 19981); + + File testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should succeed").isTrue(); + + try (FileAppender writer = + Parquet.write(Files.localOutput(testFile)) + .schema(COMPLEX_SCHEMA) + .createWriterFunc( + msgType -> + SparkParquetWriters.buildWriter( + SparkSchemaUtil.convert(COMPLEX_SCHEMA), msgType)) + .build()) { + writer.addAll(records); + } + + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)) + .project(COMPLEX_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(COMPLEX_SCHEMA, type)) + .build()) { + Iterator expected = records.iterator(); + Iterator rows = reader.iterator(); + for (int i = 0; i < numRows; i += 1) { + 
assertThat(rows).as("Should have expected number of rows").hasNext(); + TestHelpers.assertEquals(COMPLEX_SCHEMA, expected.next(), rows.next()); + } + assertThat(rows).as("Should not have extra rows").isExhausted(); + } + } + + @Test + public void testFpp() throws IOException, NoSuchFieldException, IllegalAccessException { + File testFile = File.createTempFile("junit", null, temp.toFile()); + try (FileAppender writer = + Parquet.write(Files.localOutput(testFile)) + .schema(SCHEMA) + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_FPP_PREFIX + "id", "0.05") + .createWriterFunc( + msgType -> + SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(SCHEMA), msgType)) + .build()) { + // Using reflection to access the private 'props' field in ParquetWriter + Field propsField = writer.getClass().getDeclaredField("props"); + propsField.setAccessible(true); + ParquetProperties props = (ParquetProperties) propsField.get(writer); + MessageType parquetSchema = ParquetSchemaUtil.convert(SCHEMA, "test"); + ColumnDescriptor descriptor = parquetSchema.getColumnDescription(new String[] {"id"}); + double fpp = props.getBloomFilterFPP(descriptor).getAsDouble(); + assertThat(fpp).isEqualTo(0.05); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java new file mode 100644 index 000000000000..e9a7c1c07a5a --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
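testFpp above uses reflection to confirm that the false-positive probability configured on the appender reaches Parquet's ParquetProperties. At the table level the same prefixes are normally set as table properties rather than directly on a writer; a minimal sketch follows, assuming an existing Hadoop-catalog table at a placeholder location (the class name and path are illustrative, not part of this change).

    import static org.apache.iceberg.TableProperties.PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX;
    import static org.apache.iceberg.TableProperties.PARQUET_BLOOM_FILTER_COLUMN_FPP_PREFIX;

    import org.apache.iceberg.Table;
    import org.apache.iceberg.hadoop.HadoopTables;

    public class BloomFilterPropertiesSketch {
      public static void main(String[] args) {
        // Placeholder location; any existing Hadoop-catalog table works here.
        Table table = new HadoopTables().load("/tmp/warehouse/db/tbl");
        table
            .updateProperties()
            // Enable a bloom filter on the "id" column and set its false-positive probability,
            // mirroring the appender-level settings verified by testFpp.
            .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id", "true")
            .set(PARQUET_BLOOM_FILTER_COLUMN_FPP_PREFIX + "id", "0.05")
            .commit();
      }
    }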
+ */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.math.BigDecimal; +import java.util.Iterator; +import java.util.List; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.RandomGenericData; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.data.orc.GenericOrcReader; +import org.apache.iceberg.data.orc.GenericOrcWriter; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.jupiter.api.Test; + +public class TestSparkRecordOrcReaderWriter extends AvroDataTest { + private static final int NUM_RECORDS = 200; + + private void writeAndValidate(Schema schema, List expectedRecords) throws IOException { + final File originalFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(originalFile.delete()).as("Delete should succeed").isTrue(); + + // Write few generic records into the original test file. + try (FileAppender writer = + ORC.write(Files.localOutput(originalFile)) + .createWriterFunc(GenericOrcWriter::buildWriter) + .schema(schema) + .build()) { + writer.addAll(expectedRecords); + } + + // Read into spark InternalRow from the original test file. + List internalRows = Lists.newArrayList(); + try (CloseableIterable reader = + ORC.read(Files.localInput(originalFile)) + .project(schema) + .createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema)) + .build()) { + reader.forEach(internalRows::add); + assertEqualsUnsafe(schema.asStruct(), expectedRecords, reader, expectedRecords.size()); + } + + final File anotherFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(anotherFile.delete()).as("Delete should succeed").isTrue(); + + // Write those spark InternalRows into a new file again. + try (FileAppender writer = + ORC.write(Files.localOutput(anotherFile)) + .createWriterFunc(SparkOrcWriter::new) + .schema(schema) + .build()) { + writer.addAll(internalRows); + } + + // Check whether the InternalRows are expected records. + try (CloseableIterable reader = + ORC.read(Files.localInput(anotherFile)) + .project(schema) + .createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema)) + .build()) { + assertEqualsUnsafe(schema.asStruct(), expectedRecords, reader, expectedRecords.size()); + } + + // Read into iceberg GenericRecord and check again. 
+ try (CloseableIterable reader = + ORC.read(Files.localInput(anotherFile)) + .createReaderFunc(typeDesc -> GenericOrcReader.buildReader(schema, typeDesc)) + .project(schema) + .build()) { + assertRecordEquals(expectedRecords, reader, expectedRecords.size()); + } + } + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + List expectedRecords = RandomGenericData.generate(schema, NUM_RECORDS, 1992L); + writeAndValidate(schema, expectedRecords); + } + + @Test + public void testDecimalWithTrailingZero() throws IOException { + Schema schema = + new Schema( + required(1, "d1", Types.DecimalType.of(10, 2)), + required(2, "d2", Types.DecimalType.of(20, 5)), + required(3, "d3", Types.DecimalType.of(38, 20))); + + List expected = Lists.newArrayList(); + + GenericRecord record = GenericRecord.create(schema); + record.set(0, new BigDecimal("101.00")); + record.set(1, new BigDecimal("10.00E-3")); + record.set(2, new BigDecimal("1001.0000E-16")); + + expected.add(record.copy()); + + writeAndValidate(schema, expected); + } + + private static void assertRecordEquals( + Iterable expected, Iterable actual, int size) { + Iterator expectedIter = expected.iterator(); + Iterator actualIter = actual.iterator(); + for (int i = 0; i < size; i += 1) { + assertThat(expectedIter).as("Expected iterator should have more rows").hasNext(); + assertThat(actualIter).as("Actual iterator should have more rows").hasNext(); + assertThat(actualIter.next()).as("Should have same rows.").isEqualTo(expectedIter.next()); + } + assertThat(expectedIter).as("Expected iterator should not have any extra rows.").isExhausted(); + assertThat(actualIter).as("Actual iterator should not have any extra rows.").isExhausted(); + } + + private static void assertEqualsUnsafe( + Types.StructType struct, Iterable expected, Iterable actual, int size) { + Iterator expectedIter = expected.iterator(); + Iterator actualIter = actual.iterator(); + for (int i = 0; i < size; i += 1) { + assertThat(expectedIter).as("Expected iterator should have more rows").hasNext(); + assertThat(actualIter).as("Actual iterator should have more rows").hasNext(); + GenericsHelpers.assertEqualsUnsafe(struct, expectedIter.next(), actualIter.next()); + } + assertThat(expectedIter).as("Expected iterator should not have any extra rows.").isExhausted(); + assertThat(actualIter).as("Actual iterator should not have any extra rows.").isExhausted(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestVectorizedOrcDataReader.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestVectorizedOrcDataReader.java new file mode 100644 index 000000000000..b247ef20d152 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/TestVectorizedOrcDataReader.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import org.apache.iceberg.Files; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.data.orc.GenericOrcWriter; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.DataWriter; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterators; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; +import org.apache.iceberg.types.Types; +import org.apache.orc.OrcConf; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.assertj.core.api.WithAssertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestVectorizedOrcDataReader implements WithAssertions { + @TempDir public static Path temp; + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "binary", Types.BinaryType.get()), + Types.NestedField.required( + 4, "array", Types.ListType.ofOptional(5, Types.IntegerType.get()))); + private static OutputFile outputFile; + + @BeforeAll + public static void createDataFile() throws IOException { + GenericRecord bufferRecord = GenericRecord.create(SCHEMA); + + ImmutableList.Builder builder = ImmutableList.builder(); + builder.add( + bufferRecord.copy( + ImmutableMap.of("id", 1L, "data", "a", "array", Collections.singletonList(1)))); + builder.add( + bufferRecord.copy(ImmutableMap.of("id", 2L, "data", "b", "array", Arrays.asList(2, 3)))); + builder.add( + bufferRecord.copy(ImmutableMap.of("id", 3L, "data", "c", "array", Arrays.asList(3, 4, 5)))); + builder.add( + bufferRecord.copy( + ImmutableMap.of("id", 4L, "data", "d", "array", Arrays.asList(4, 5, 6, 7)))); + builder.add( + bufferRecord.copy( + ImmutableMap.of("id", 5L, "data", "e", "array", Arrays.asList(5, 6, 7, 8, 9)))); + + outputFile = Files.localOutput(File.createTempFile("test", ".orc", temp.toFile())); + + try (DataWriter dataWriter = + ORC.writeData(outputFile) + .schema(SCHEMA) + .createWriterFunc(GenericOrcWriter::buildWriter) + .overwrite() + .withSpec(PartitionSpec.unpartitioned()) + .build()) { + for (Record record : builder.build()) { + dataWriter.write(record); + } + } + } + + private Iterator batchesToRows(Iterator batches) { + return Iterators.concat(Iterators.transform(batches, ColumnarBatch::rowIterator)); + } + + private void validateAllRows(Iterator rows) { + long rowCount = 0; + long expId = 1; + char expChar = 'a'; + while (rows.hasNext()) { + InternalRow row = rows.next(); + assertThat(row.getLong(0)).isEqualTo(expId); + assertThat(row.getString(1)).isEqualTo(Character.toString(expChar)); + assertThat(row.isNullAt(2)).isTrue(); + expId 
+= 1; + expChar += 1; + rowCount += 1; + } + assertThat(rowCount).isEqualTo(5); + } + + @Test + public void testReader() throws IOException { + try (CloseableIterable reader = + ORC.read(outputFile.toInputFile()) + .project(SCHEMA) + .createBatchedReaderFunc( + readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(SCHEMA, readOrcSchema, ImmutableMap.of())) + .build()) { + validateAllRows(batchesToRows(reader.iterator())); + } + } + + @Test + public void testReaderWithFilter() throws IOException { + try (CloseableIterable reader = + ORC.read(outputFile.toInputFile()) + .project(SCHEMA) + .createBatchedReaderFunc( + readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(SCHEMA, readOrcSchema, ImmutableMap.of())) + .filter(Expressions.equal("id", 3L)) + .config(OrcConf.ALLOW_SARG_TO_FILTER.getAttribute(), String.valueOf(true)) + .build()) { + validateAllRows(batchesToRows(reader.iterator())); + } + } + + @Test + public void testWithFilterWithSelected() throws IOException { + try (CloseableIterable reader = + ORC.read(outputFile.toInputFile()) + .project(SCHEMA) + .createBatchedReaderFunc( + readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(SCHEMA, readOrcSchema, ImmutableMap.of())) + .filter(Expressions.equal("id", 3L)) + .config(OrcConf.ALLOW_SARG_TO_FILTER.getAttribute(), String.valueOf(true)) + .config(OrcConf.READER_USE_SELECTED.getAttribute(), String.valueOf(true)) + .build()) { + Iterator rows = batchesToRows(reader.iterator()); + assertThat(rows).hasNext(); + InternalRow row = rows.next(); + assertThat(row.getLong(0)).isEqualTo(3L); + assertThat(row.getString(1)).isEqualTo("c"); + assertThat(row.getArray(3).toIntArray()).isEqualTo(new int[] {3, 4, 5}); + assertThat(rows).isExhausted(); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java new file mode 100644 index 000000000000..bc4e722bc869 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
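The last two ORC tests above differ only in OrcConf.READER_USE_SELECTED. With ALLOW_SARG_TO_FILTER alone, the pushed-down predicate can only skip whole stripes, so every row of a surviving stripe still comes back (which is why testReaderWithFilter validates all five rows); READER_USE_SELECTED additionally makes the vectorized reader honor ORC's selected vector, so only matching rows surface. The compact sketch below restates that knob using only APIs that appear in the tests above; the helper class and method names are illustrative.

    import org.apache.iceberg.Schema;
    import org.apache.iceberg.expressions.Expressions;
    import org.apache.iceberg.io.CloseableIterable;
    import org.apache.iceberg.io.InputFile;
    import org.apache.iceberg.orc.ORC;
    import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
    import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders;
    import org.apache.orc.OrcConf;
    import org.apache.spark.sql.vectorized.ColumnarBatch;

    class OrcSelectedVectorSketch {
      // Returns batches for rows with id == 3. With useSelected=false the search argument only
      // skips whole stripes, so a surviving batch still carries every row of its stripe; with
      // useSelected=true the reader honors ORC's selected vector and exposes only matching rows.
      static CloseableIterable<ColumnarBatch> read(InputFile file, Schema schema, boolean useSelected) {
        return ORC.read(file)
            .project(schema)
            .createBatchedReaderFunc(
                readOrcSchema ->
                    VectorizedSparkOrcReaders.buildReader(schema, readOrcSchema, ImmutableMap.of()))
            .filter(Expressions.equal("id", 3L))
            .config(OrcConf.ALLOW_SARG_TO_FILTER.getAttribute(), "true")
            .config(OrcConf.READER_USE_SELECTED.getAttribute(), String.valueOf(useSelected))
            .build();
      }
    }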
+ */ +package org.apache.iceberg.spark.data.parquet.vectorized; + +import static org.apache.iceberg.TableProperties.PARQUET_DICT_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.PARQUET_PAGE_ROW_LIMIT; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Iterator; +import java.util.List; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Function; +import org.apache.iceberg.relocated.com.google.common.collect.FluentIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +public class TestParquetDictionaryEncodedVectorizedReads extends TestParquetVectorizedReads { + + protected static SparkSession spark = null; + + @BeforeAll + public static void startSpark() { + spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + if (spark != null) { + spark.stop(); + spark = null; + } + } + + @Override + Iterable generateData( + Schema schema, + int numRecords, + long seed, + float nullPercentage, + Function transform) { + Iterable data = + RandomData.generateDictionaryEncodableData(schema, numRecords, seed, nullPercentage); + return transform == IDENTITY ? 
data : Iterables.transform(data, transform); + } + + @Test + @Override + @Disabled // Ignored since this code path is already tested in TestParquetVectorizedReads + public void testVectorizedReadsWithNewContainers() throws IOException {} + + @Test + public void testMixedDictionaryNonDictionaryReads() throws IOException { + Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields()); + File dictionaryEncodedFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(dictionaryEncodedFile.delete()).as("Delete should succeed").isTrue(); + Iterable dictionaryEncodableData = + RandomData.generateDictionaryEncodableData( + schema, 10000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE); + try (FileAppender writer = + getParquetWriter(schema, dictionaryEncodedFile)) { + writer.addAll(dictionaryEncodableData); + } + + File plainEncodingFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(plainEncodingFile.delete()).as("Delete should succeed").isTrue(); + Iterable nonDictionaryData = + RandomData.generate(schema, 10000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE); + try (FileAppender writer = getParquetWriter(schema, plainEncodingFile)) { + writer.addAll(nonDictionaryData); + } + + int rowGroupSize = PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT; + File mixedFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(mixedFile.delete()).as("Delete should succeed").isTrue(); + Parquet.concat( + ImmutableList.of(dictionaryEncodedFile, plainEncodingFile, dictionaryEncodedFile), + mixedFile, + rowGroupSize, + schema, + ImmutableMap.of()); + assertRecordsMatch( + schema, + 30000, + FluentIterable.concat(dictionaryEncodableData, nonDictionaryData, dictionaryEncodableData), + mixedFile, + true, + BATCH_SIZE); + } + + @Test + public void testBinaryNotAllPagesDictionaryEncoded() throws IOException { + Schema schema = new Schema(Types.NestedField.required(1, "bytes", Types.BinaryType.get())); + File parquetFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(parquetFile.delete()).as("Delete should succeed").isTrue(); + + Iterable records = RandomData.generateFallbackData(schema, 500, 0L, 100); + try (FileAppender writer = + Parquet.write(Files.localOutput(parquetFile)) + .schema(schema) + .set(PARQUET_DICT_SIZE_BYTES, "4096") + .set(PARQUET_PAGE_ROW_LIMIT, "100") + .build()) { + writer.addAll(records); + } + + // After the above, parquetFile contains one column chunk of binary data in five pages, + // the first two RLE dictionary encoded, and the remaining three plain encoded. + assertRecordsMatch(schema, 500, records, parquetFile, true, BATCH_SIZE); + } + + /** + * decimal_dict_and_plain_encoding.parquet contains one column chunk of decimal(38, 0) data in two + * pages, one RLE dictionary encoded and one plain encoded, each with 200 rows. 
+ */ + @Test + public void testDecimalNotAllPagesDictionaryEncoded() throws Exception { + Schema schema = new Schema(Types.NestedField.required(1, "id", Types.DecimalType.of(38, 0))); + Path path = + Paths.get( + getClass() + .getClassLoader() + .getResource("decimal_dict_and_plain_encoding.parquet") + .toURI()); + + Dataset df = spark.read().parquet(path.toString()); + List expected = df.collectAsList(); + long expectedSize = df.count(); + + Parquet.ReadBuilder readBuilder = + Parquet.read(Files.localInput(path.toFile())) + .project(schema) + .createBatchedReaderFunc( + type -> + VectorizedSparkParquetReaders.buildReader( + schema, type, ImmutableMap.of(), null)); + + try (CloseableIterable batchReader = readBuilder.build()) { + Iterator expectedIter = expected.iterator(); + Iterator batches = batchReader.iterator(); + int numRowsRead = 0; + while (batches.hasNext()) { + ColumnarBatch batch = batches.next(); + numRowsRead += batch.numRows(); + TestHelpers.assertEqualsBatchWithRows(schema.asStruct(), expectedIter, batch); + } + assertThat(numRowsRead).isEqualTo(expectedSize); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java new file mode 100644 index 000000000000..e6887c6f47b5 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.parquet.vectorized; + +import java.io.File; +import java.io.IOException; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Function; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.data.RandomData; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +public class TestParquetDictionaryFallbackToPlainEncodingVectorizedReads + extends TestParquetVectorizedReads { + private static final int NUM_ROWS = 1_000_000; + + @Override + protected int getNumRows() { + return NUM_ROWS; + } + + @Override + Iterable generateData( + Schema schema, + int numRecords, + long seed, + float nullPercentage, + Function transform) { + // TODO: take into account nullPercentage when generating fallback encoding data + Iterable data = RandomData.generateFallbackData(schema, numRecords, seed, numRecords / 20); + return transform == IDENTITY ? data : Iterables.transform(data, transform); + } + + @Override + FileAppender getParquetWriter(Schema schema, File testFile) + throws IOException { + return Parquet.write(Files.localOutput(testFile)) + .schema(schema) + .named("test") + .set(TableProperties.PARQUET_DICT_SIZE_BYTES, "512000") + .build(); + } + + @Test + @Override + @Disabled // Fallback encoding not triggered when data is mostly null + public void testMostlyNullsForOptionalFields() {} + + @Test + @Override + @Disabled // Ignored since this code path is already tested in TestParquetVectorizedReads + public void testVectorizedReadsWithNewContainers() throws IOException {} +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java new file mode 100644 index 000000000000..5c4b216aff94 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.parquet.vectorized; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Function; +import org.apache.iceberg.relocated.com.google.common.base.Strings; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.data.AvroDataTest; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.column.ParquetProperties; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Type; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +public class TestParquetVectorizedReads extends AvroDataTest { + private static final int NUM_ROWS = 200_000; + static final int BATCH_SIZE = 10_000; + + static final Function IDENTITY = record -> record; + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + writeAndValidate(schema, getNumRows(), 0L, RandomData.DEFAULT_NULL_PERCENTAGE, true); + } + + private void writeAndValidate( + Schema schema, int numRecords, long seed, float nullPercentage, boolean reuseContainers) + throws IOException { + writeAndValidate( + schema, numRecords, seed, nullPercentage, reuseContainers, BATCH_SIZE, IDENTITY); + } + + private void writeAndValidate( + Schema schema, + int numRecords, + long seed, + float nullPercentage, + boolean reuseContainers, + int batchSize, + Function transform) + throws IOException { + // Write test data + assumeThat( + TypeUtil.find( + schema, + type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get())) + .as("Parquet Avro cannot write non-string map keys") + .isNull(); + + Iterable expected = + generateData(schema, numRecords, seed, nullPercentage, transform); + + // write a test parquet file using iceberg writer + File testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should succeed").isTrue(); + + try (FileAppender writer = getParquetWriter(schema, testFile)) { + writer.addAll(expected); + } + assertRecordsMatch(schema, numRecords, expected, testFile, reuseContainers, batchSize); + } + + protected int getNumRows() { + return NUM_ROWS; + } + + Iterable generateData( + Schema schema, + int numRecords, + long seed, + float nullPercentage, + Function transform) { + Iterable data = + RandomData.generate(schema, numRecords, seed, nullPercentage); + return transform == IDENTITY ? 
data : Iterables.transform(data, transform); + } + + FileAppender getParquetWriter(Schema schema, File testFile) + throws IOException { + return Parquet.write(Files.localOutput(testFile)).schema(schema).named("test").build(); + } + + FileAppender getParquetV2Writer(Schema schema, File testFile) + throws IOException { + return Parquet.write(Files.localOutput(testFile)) + .schema(schema) + .named("test") + .writerVersion(ParquetProperties.WriterVersion.PARQUET_2_0) + .build(); + } + + void assertRecordsMatch( + Schema schema, + int expectedSize, + Iterable expected, + File testFile, + boolean reuseContainers, + int batchSize) + throws IOException { + Parquet.ReadBuilder readBuilder = + Parquet.read(Files.localInput(testFile)) + .project(schema) + .recordsPerBatch(batchSize) + .createBatchedReaderFunc( + type -> + VectorizedSparkParquetReaders.buildReader( + schema, type, Maps.newHashMap(), null)); + if (reuseContainers) { + readBuilder.reuseContainers(); + } + try (CloseableIterable batchReader = readBuilder.build()) { + Iterator expectedIter = expected.iterator(); + Iterator batches = batchReader.iterator(); + int numRowsRead = 0; + while (batches.hasNext()) { + ColumnarBatch batch = batches.next(); + numRowsRead += batch.numRows(); + TestHelpers.assertEqualsBatch(schema.asStruct(), expectedIter, batch); + } + assertThat(numRowsRead).isEqualTo(expectedSize); + } + } + + @Override + @Test + @Disabled + public void testArray() {} + + @Override + @Test + @Disabled + public void testArrayOfStructs() {} + + @Override + @Test + @Disabled + public void testMap() {} + + @Override + @Test + @Disabled + public void testNumericMapKey() {} + + @Override + @Test + @Disabled + public void testComplexMapKey() {} + + @Override + @Test + @Disabled + public void testMapOfStructs() {} + + @Override + @Test + @Disabled + public void testMixedTypes() {} + + @Test + @Override + public void testNestedStruct() { + assertThatThrownBy( + () -> + VectorizedSparkParquetReaders.buildReader( + TypeUtil.assignIncreasingFreshIds( + new Schema(required(1, "struct", SUPPORTED_PRIMITIVES))), + new MessageType( + "struct", new GroupType(Type.Repetition.OPTIONAL, "struct").withId(1)), + Maps.newHashMap(), + null)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessage("Vectorized reads are not supported yet for struct fields"); + } + + @Test + public void testMostlyNullsForOptionalFields() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds(new Schema(SUPPORTED_PRIMITIVES.fields())), + getNumRows(), + 0L, + 0.99f, + true); + } + + @Test + public void testSettingArrowValidityVector() throws IOException { + writeAndValidate( + new Schema(Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asOptional)), + getNumRows(), + 0L, + RandomData.DEFAULT_NULL_PERCENTAGE, + true); + } + + @Test + public void testVectorizedReadsWithNewContainers() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds(new Schema(SUPPORTED_PRIMITIVES.fields())), + getNumRows(), + 0L, + RandomData.DEFAULT_NULL_PERCENTAGE, + false); + } + + @Test + public void testVectorizedReadsWithReallocatedArrowBuffers() throws IOException { + // With a batch size of 2, 256 bytes are allocated in the VarCharVector. By adding strings of + // length 512, the vector will need to be reallocated for storing the batch. 
+    writeAndValidate(
+        new Schema(
+            Lists.newArrayList(
+                SUPPORTED_PRIMITIVES.field("id"), SUPPORTED_PRIMITIVES.field("data"))),
+        10,
+        0L,
+        RandomData.DEFAULT_NULL_PERCENTAGE,
+        true,
+        2,
+        record -> {
+          if (record.get("data") != null) {
+            record.put("data", Strings.padEnd((String) record.get("data"), 512, 'a'));
+          } else {
+            record.put("data", Strings.padEnd("", 512, 'a'));
+          }
+          return record;
+        });
+  }
+
+  @Test
+  public void testReadsForTypePromotedColumns() throws Exception {
+    Schema writeSchema =
+        new Schema(
+            required(100, "id", Types.LongType.get()),
+            optional(101, "int_data", Types.IntegerType.get()),
+            optional(102, "float_data", Types.FloatType.get()),
+            optional(103, "decimal_data", Types.DecimalType.of(10, 5)));
+
+    File dataFile = File.createTempFile("junit", null, temp.toFile());
+    assertThat(dataFile.delete()).as("Delete should succeed").isTrue();
+    Iterable<GenericData.Record> data =
+        generateData(writeSchema, 30000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE, IDENTITY);
+    try (FileAppender<GenericData.Record> writer = getParquetWriter(writeSchema, dataFile)) {
+      writer.addAll(data);
+    }
+
+    Schema readSchema =
+        new Schema(
+            required(100, "id", Types.LongType.get()),
+            optional(101, "int_data", Types.LongType.get()),
+            optional(102, "float_data", Types.DoubleType.get()),
+            optional(103, "decimal_data", Types.DecimalType.of(25, 5)));
+
+    assertRecordsMatch(readSchema, 30000, data, dataFile, false, BATCH_SIZE);
+  }
+
+  @Test
+  public void testSupportedReadsForParquetV2() throws Exception {
+    // Float and double columns are written with plain encoding in Parquet V2; Parquet V2 also
+    // dictionary-encodes decimals that use fixed-length binary (i.e. decimals wider than 8 bytes).
+    Schema schema =
+        new Schema(
+            optional(102, "float_data", Types.FloatType.get()),
+            optional(103, "double_data", Types.DoubleType.get()),
+            optional(104, "decimal_data", Types.DecimalType.of(25, 5)));
+
+    File dataFile = File.createTempFile("junit", null, temp.toFile());
+    assertThat(dataFile.delete()).as("Delete should succeed").isTrue();
+    Iterable<GenericData.Record> data =
+        generateData(schema, 30000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE, IDENTITY);
+    try (FileAppender<GenericData.Record> writer = getParquetV2Writer(schema, dataFile)) {
+      writer.addAll(data);
+    }
+    assertRecordsMatch(schema, 30000, data, dataFile, false, BATCH_SIZE);
+  }
+
+  @Test
+  public void testUnsupportedReadsForParquetV2() throws Exception {
+    // Long, int, string, and other column types use delta encoding in Parquet V2, which is not
+    // supported for vectorized reads.
+    Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields());
+    File dataFile = File.createTempFile("junit", null, temp.toFile());
+    assertThat(dataFile.delete()).as("Delete should succeed").isTrue();
+    Iterable<GenericData.Record> data =
+        generateData(schema, 30000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE, IDENTITY);
+    try (FileAppender<GenericData.Record> writer = getParquetV2Writer(schema, dataFile)) {
+      writer.addAll(data);
+    }
+    assertThatThrownBy(() -> assertRecordsMatch(schema, 30000, data, dataFile, false, BATCH_SIZE))
+        .isInstanceOf(UnsupportedOperationException.class)
+        .hasMessageStartingWith("Cannot support vectorized reads for column")
+        .hasMessageEndingWith("Disable vectorized reads to read this table/file");
+  }
+}
diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/functions/TestSparkFunctions.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/functions/TestSparkFunctions.java new file mode 100644 index 000000000000..38ce0d4d95f1 --- /dev/null +++
b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/functions/TestSparkFunctions.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DecimalType; +import org.junit.jupiter.api.Test; + +public class TestSparkFunctions { + + @Test + public void testBuildYearsFunctionFromClass() { + UnboundFunction expected = new YearsFunction(); + + YearsFunction.DateToYearsFunction dateToYearsFunc = new YearsFunction.DateToYearsFunction(); + checkBuildFunc(dateToYearsFunc, expected); + + YearsFunction.TimestampToYearsFunction tsToYearsFunc = + new YearsFunction.TimestampToYearsFunction(); + checkBuildFunc(tsToYearsFunc, expected); + + YearsFunction.TimestampNtzToYearsFunction tsNtzToYearsFunc = + new YearsFunction.TimestampNtzToYearsFunction(); + checkBuildFunc(tsNtzToYearsFunc, expected); + } + + @Test + public void testBuildMonthsFunctionFromClass() { + UnboundFunction expected = new MonthsFunction(); + + MonthsFunction.DateToMonthsFunction dateToMonthsFunc = + new MonthsFunction.DateToMonthsFunction(); + checkBuildFunc(dateToMonthsFunc, expected); + + MonthsFunction.TimestampToMonthsFunction tsToMonthsFunc = + new MonthsFunction.TimestampToMonthsFunction(); + checkBuildFunc(tsToMonthsFunc, expected); + + MonthsFunction.TimestampNtzToMonthsFunction tsNtzToMonthsFunc = + new MonthsFunction.TimestampNtzToMonthsFunction(); + checkBuildFunc(tsNtzToMonthsFunc, expected); + } + + @Test + public void testBuildDaysFunctionFromClass() { + UnboundFunction expected = new DaysFunction(); + + DaysFunction.DateToDaysFunction dateToDaysFunc = new DaysFunction.DateToDaysFunction(); + checkBuildFunc(dateToDaysFunc, expected); + + DaysFunction.TimestampToDaysFunction tsToDaysFunc = new DaysFunction.TimestampToDaysFunction(); + checkBuildFunc(tsToDaysFunc, expected); + + DaysFunction.TimestampNtzToDaysFunction tsNtzToDaysFunc = + new DaysFunction.TimestampNtzToDaysFunction(); + checkBuildFunc(tsNtzToDaysFunc, expected); + } + + @Test + public void testBuildHoursFunctionFromClass() { + UnboundFunction expected = new HoursFunction(); + + HoursFunction.TimestampToHoursFunction tsToHoursFunc = + new HoursFunction.TimestampToHoursFunction(); + checkBuildFunc(tsToHoursFunc, expected); + + HoursFunction.TimestampNtzToHoursFunction tsNtzToHoursFunc = + new HoursFunction.TimestampNtzToHoursFunction(); + checkBuildFunc(tsNtzToHoursFunc, expected); + } + + @Test + public void testBuildBucketFunctionFromClass() { + 
UnboundFunction expected = new BucketFunction(); + + BucketFunction.BucketInt bucketDateFunc = new BucketFunction.BucketInt(DataTypes.DateType); + checkBuildFunc(bucketDateFunc, expected); + + BucketFunction.BucketInt bucketIntFunc = new BucketFunction.BucketInt(DataTypes.IntegerType); + checkBuildFunc(bucketIntFunc, expected); + + BucketFunction.BucketLong bucketLongFunc = new BucketFunction.BucketLong(DataTypes.LongType); + checkBuildFunc(bucketLongFunc, expected); + + BucketFunction.BucketLong bucketTsFunc = new BucketFunction.BucketLong(DataTypes.TimestampType); + checkBuildFunc(bucketTsFunc, expected); + + BucketFunction.BucketLong bucketTsNtzFunc = + new BucketFunction.BucketLong(DataTypes.TimestampNTZType); + checkBuildFunc(bucketTsNtzFunc, expected); + + BucketFunction.BucketDecimal bucketDecimalFunc = + new BucketFunction.BucketDecimal(new DecimalType()); + checkBuildFunc(bucketDecimalFunc, expected); + + BucketFunction.BucketString bucketStringFunc = new BucketFunction.BucketString(); + checkBuildFunc(bucketStringFunc, expected); + + BucketFunction.BucketBinary bucketBinary = new BucketFunction.BucketBinary(); + checkBuildFunc(bucketBinary, expected); + } + + @Test + public void testBuildTruncateFunctionFromClass() { + UnboundFunction expected = new TruncateFunction(); + + TruncateFunction.TruncateTinyInt truncateTinyIntFunc = new TruncateFunction.TruncateTinyInt(); + checkBuildFunc(truncateTinyIntFunc, expected); + + TruncateFunction.TruncateSmallInt truncateSmallIntFunc = + new TruncateFunction.TruncateSmallInt(); + checkBuildFunc(truncateSmallIntFunc, expected); + + TruncateFunction.TruncateInt truncateIntFunc = new TruncateFunction.TruncateInt(); + checkBuildFunc(truncateIntFunc, expected); + + TruncateFunction.TruncateBigInt truncateBigIntFunc = new TruncateFunction.TruncateBigInt(); + checkBuildFunc(truncateBigIntFunc, expected); + + TruncateFunction.TruncateDecimal truncateDecimalFunc = + new TruncateFunction.TruncateDecimal(10, 9); + checkBuildFunc(truncateDecimalFunc, expected); + + TruncateFunction.TruncateString truncateStringFunc = new TruncateFunction.TruncateString(); + checkBuildFunc(truncateStringFunc, expected); + + TruncateFunction.TruncateBinary truncateBinaryFunc = new TruncateFunction.TruncateBinary(); + checkBuildFunc(truncateBinaryFunc, expected); + } + + private void checkBuildFunc(ScalarFunction function, UnboundFunction expected) { + UnboundFunction actual = SparkFunctions.loadFunctionByClass(function.getClass()); + + assertThat(actual).isNotNull(); + assertThat(actual.name()).isEqualTo(expected.name()); + assertThat(actual.description()).isEqualTo(expected.description()); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/ComplexRecord.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/ComplexRecord.java new file mode 100644 index 000000000000..42e8552578cd --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/ComplexRecord.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; + +public class ComplexRecord { + private long id; + private NestedRecord struct; + + public ComplexRecord() {} + + public ComplexRecord(long id, NestedRecord struct) { + this.id = id; + this.struct = struct; + } + + public long getId() { + return id; + } + + public void setId(long id) { + this.id = id; + } + + public NestedRecord getStruct() { + return struct; + } + + public void setStruct(NestedRecord struct) { + this.struct = struct; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + ComplexRecord record = (ComplexRecord) o; + return id == record.id && Objects.equal(struct, record.struct); + } + + @Override + public int hashCode() { + return Objects.hashCode(id, struct); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("id", id).add("struct", struct).toString(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/FilePathLastModifiedRecord.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/FilePathLastModifiedRecord.java new file mode 100644 index 000000000000..c62c1de6ba33 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/FilePathLastModifiedRecord.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.sql.Timestamp; +import java.util.Objects; + +public class FilePathLastModifiedRecord { + private String filePath; + private Timestamp lastModified; + + public FilePathLastModifiedRecord() {} + + public FilePathLastModifiedRecord(String filePath, Timestamp lastModified) { + this.filePath = filePath; + this.lastModified = lastModified; + } + + public String getFilePath() { + return filePath; + } + + public void setFilePath(String filePath) { + this.filePath = filePath; + } + + public Timestamp getLastModified() { + return lastModified; + } + + public void setLastModified(Timestamp lastModified) { + this.lastModified = lastModified; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FilePathLastModifiedRecord that = (FilePathLastModifiedRecord) o; + return Objects.equals(filePath, that.filePath) + && Objects.equals(lastModified, that.lastModified); + } + + @Override + public int hashCode() { + return Objects.hash(filePath, lastModified); + } + + @Override + public String toString() { + return "FilePathLastModifiedRecord{" + + "filePath='" + + filePath + + '\'' + + ", lastModified='" + + lastModified + + '\'' + + '}'; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/FourColumnRecord.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/FourColumnRecord.java new file mode 100644 index 000000000000..0f9529e4d105 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/FourColumnRecord.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.iceberg.spark.source;
+
+import java.util.Objects;
+
+public class FourColumnRecord {
+  private Integer c1;
+  private String c2;
+  private String c3;
+  private String c4;
+
+  public FourColumnRecord() {}
+
+  public FourColumnRecord(Integer c1, String c2, String c3, String c4) {
+    this.c1 = c1;
+    this.c2 = c2;
+    this.c3 = c3;
+    this.c4 = c4;
+  }
+
+  public Integer getC1() {
+    return c1;
+  }
+
+  public void setC1(Integer c1) {
+    this.c1 = c1;
+  }
+
+  public String getC2() {
+    return c2;
+  }
+
+  public void setC2(String c2) {
+    this.c2 = c2;
+  }
+
+  public String getC3() {
+    return c3;
+  }
+
+  public void setC3(String c3) {
+    this.c3 = c3;
+  }
+
+  public String getC4() {
+    return c4;
+  }
+
+  public void setC4(String c4) {
+    this.c4 = c4;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    FourColumnRecord that = (FourColumnRecord) o;
+    return Objects.equals(c1, that.c1)
+        && Objects.equals(c2, that.c2)
+        && Objects.equals(c3, that.c3)
+        && Objects.equals(c4, that.c4);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(c1, c2, c3, c4);
+  }
+
+  @Override
+  public String toString() {
+    return "FourColumnRecord{"
+        + "c1="
+        + c1
+        + ", c2='"
+        + c2
+        + '\''
+        + ", c3='"
+        + c3
+        + '\''
+        + ", c4='"
+        + c4
+        + '\''
+        + '}';
+  }
+}
diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/LogMessage.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/LogMessage.java
new file mode 100644
index 000000000000..875b1009c37f
--- /dev/null
+++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/LogMessage.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package org.apache.iceberg.spark.source; + +import java.time.Instant; +import java.util.concurrent.atomic.AtomicInteger; + +public class LogMessage { + private static final AtomicInteger ID_COUNTER = new AtomicInteger(0); + + static LogMessage debug(String date, String message) { + return new LogMessage(ID_COUNTER.getAndIncrement(), date, "DEBUG", message); + } + + static LogMessage debug(String date, String message, Instant timestamp) { + return new LogMessage(ID_COUNTER.getAndIncrement(), date, "DEBUG", message, timestamp); + } + + static LogMessage info(String date, String message) { + return new LogMessage(ID_COUNTER.getAndIncrement(), date, "INFO", message); + } + + static LogMessage info(String date, String message, Instant timestamp) { + return new LogMessage(ID_COUNTER.getAndIncrement(), date, "INFO", message, timestamp); + } + + static LogMessage error(String date, String message) { + return new LogMessage(ID_COUNTER.getAndIncrement(), date, "ERROR", message); + } + + static LogMessage error(String date, String message, Instant timestamp) { + return new LogMessage(ID_COUNTER.getAndIncrement(), date, "ERROR", message, timestamp); + } + + static LogMessage warn(String date, String message) { + return new LogMessage(ID_COUNTER.getAndIncrement(), date, "WARN", message); + } + + static LogMessage warn(String date, String message, Instant timestamp) { + return new LogMessage(ID_COUNTER.getAndIncrement(), date, "WARN", message, timestamp); + } + + private int id; + private String date; + private String level; + private String message; + private Instant timestamp; + + private LogMessage(int id, String date, String level, String message) { + this.id = id; + this.date = date; + this.level = level; + this.message = message; + } + + private LogMessage(int id, String date, String level, String message, Instant timestamp) { + this.id = id; + this.date = date; + this.level = level; + this.message = message; + this.timestamp = timestamp; + } + + public int getId() { + return id; + } + + public void setId(int id) { + this.id = id; + } + + public String getDate() { + return date; + } + + public void setDate(String date) { + this.date = date; + } + + public String getLevel() { + return level; + } + + public void setLevel(String level) { + this.level = level; + } + + public String getMessage() { + return message; + } + + public void setMessage(String message) { + this.message = message; + } + + public Instant getTimestamp() { + return timestamp; + } + + public void setTimestamp(Instant timestamp) { + this.timestamp = timestamp; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/ManualSource.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/ManualSource.java new file mode 100644 index 000000000000..b6f172248ea9 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/ManualSource.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableProvider; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.sources.DataSourceRegister; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +public class ManualSource implements TableProvider, DataSourceRegister { + public static final String SHORT_NAME = "manual_source"; + public static final String TABLE_NAME = "TABLE_NAME"; + private static final Map TABLE_MAP = Maps.newHashMap(); + + public static void setTable(String name, Table table) { + Preconditions.checkArgument( + !TABLE_MAP.containsKey(name), "Cannot set " + name + ". It is already set"); + TABLE_MAP.put(name, table); + } + + public static void clearTables() { + TABLE_MAP.clear(); + } + + @Override + public String shortName() { + return SHORT_NAME; + } + + @Override + public StructType inferSchema(CaseInsensitiveStringMap options) { + return getTable(null, null, options).schema(); + } + + @Override + public Transform[] inferPartitioning(CaseInsensitiveStringMap options) { + return getTable(null, null, options).partitioning(); + } + + @Override + public org.apache.spark.sql.connector.catalog.Table getTable( + StructType schema, Transform[] partitioning, Map properties) { + Preconditions.checkArgument( + properties.containsKey(TABLE_NAME), "Missing property " + TABLE_NAME); + String tableName = properties.get(TABLE_NAME); + Preconditions.checkArgument(TABLE_MAP.containsKey(tableName), "Table missing " + tableName); + return TABLE_MAP.get(tableName); + } + + @Override + public boolean supportsExternalMetadata() { + return false; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/NestedRecord.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/NestedRecord.java new file mode 100644 index 000000000000..ca36bfd4938b --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/NestedRecord.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; + +public class NestedRecord { + private long innerId; + private String innerName; + + public NestedRecord() {} + + public NestedRecord(long innerId, String innerName) { + this.innerId = innerId; + this.innerName = innerName; + } + + public long getInnerId() { + return innerId; + } + + public String getInnerName() { + return innerName; + } + + public void setInnerId(long iId) { + innerId = iId; + } + + public void setInnerName(String name) { + innerName = name; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + NestedRecord that = (NestedRecord) o; + return innerId == that.innerId && Objects.equal(innerName, that.innerName); + } + + @Override + public int hashCode() { + return Objects.hashCode(innerId, innerName); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("innerId", innerId) + .add("innerName", innerName) + .toString(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/SimpleRecord.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/SimpleRecord.java new file mode 100644 index 000000000000..550e20b9338e --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/SimpleRecord.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.relocated.com.google.common.base.Objects; + +public class SimpleRecord { + private Integer id; + private String data; + + public SimpleRecord() {} + + public SimpleRecord(Integer id, String data) { + this.id = id; + this.data = data; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + public String getData() { + return data; + } + + public void setData(String data) { + this.data = data; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + SimpleRecord record = (SimpleRecord) o; + return Objects.equal(id, record.id) && Objects.equal(data, record.data); + } + + @Override + public int hashCode() { + return Objects.hashCode(id, data); + } + + @Override + public String toString() { + StringBuilder buffer = new StringBuilder(); + buffer.append("{\"id\"="); + buffer.append(id); + buffer.append(",\"data\"=\""); + buffer.append(data); + buffer.append("\"}"); + return buffer.toString(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/SparkSQLExecutionHelper.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/SparkSQLExecutionHelper.java new file mode 100644 index 000000000000..cdc380b1b6be --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/SparkSQLExecutionHelper.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.execution.ui.SQLAppStatusStore; +import org.apache.spark.sql.execution.ui.SQLExecutionUIData; +import org.apache.spark.sql.execution.ui.SQLPlanMetric; +import org.awaitility.Awaitility; +import scala.Option; + +public class SparkSQLExecutionHelper { + + private SparkSQLExecutionHelper() {} + + /** + * Finds the value of a specified metric for the last SQL query that was executed. Metric values + * are stored in the `SQLAppStatusStore` as strings. 
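+   * Metric values are only populated once the SQL execution completes, so this helper waits (via
+   * the Awaitility poll below) for the metrics to be finalized before reading them.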
+ * + * @param spark SparkSession used to run the SQL query + * @param metricName name of the metric + * @return value of the metric + */ + public static String lastExecutedMetricValue(SparkSession spark, String metricName) { + SQLAppStatusStore statusStore = spark.sharedState().statusStore(); + SQLExecutionUIData lastExecution = statusStore.executionsList().last(); + Option sqlPlanMetric = + lastExecution.metrics().find(metric -> metric.name().equals(metricName)); + assertThat(sqlPlanMetric.isDefined()) + .as(String.format("Metric '%s' not found in last execution", metricName)) + .isTrue(); + long metricId = sqlPlanMetric.get().accumulatorId(); + + // Refresh metricValues, they will remain null until the execution is complete and metrics are + // aggregated + Awaitility.await() + .atMost(Duration.ofSeconds(3)) + .pollInterval(Duration.ofMillis(100)) + .untilAsserted( + () -> assertThat(statusStore.execution(lastExecution.executionId()).get()).isNotNull()); + + SQLExecutionUIData exec = statusStore.execution(lastExecution.executionId()).get(); + + assertThat(exec.metricValues()).as("Metric values were not finalized").isNotNull(); + String metricValue = exec.metricValues().get(metricId).getOrElse(null); + assertThat(metricValue) + .as(String.format("Metric '%s' was not finalized", metricName)) + .isNotNull(); + return metricValue; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroScan.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroScan.java new file mode 100644 index 000000000000..8345a4e0a697 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroScan.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.Files.localOutput; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.UUID; +import org.apache.avro.generic.GenericData.Record; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.spark.data.AvroDataTest; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.io.TempDir; + +public class TestAvroScan extends AvroDataTest { + private static final Configuration CONF = new Configuration(); + + @TempDir private Path temp; + + private static SparkSession spark = null; + + @BeforeAll + public static void startSpark() { + TestAvroScan.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestAvroScan.spark; + TestAvroScan.spark = null; + currentSpark.stop(); + } + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + File parent = temp.resolve("avro").toFile(); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + + File avroFile = + new File(dataFolder, FileFormat.AVRO.addExtension(UUID.randomUUID().toString())); + + HadoopTables tables = new HadoopTables(CONF); + Table table = tables.create(schema, PartitionSpec.unpartitioned(), location.toString()); + + // Important: use the table's schema for the rest of the test + // When tables are created, the column ids are reassigned. + Schema tableSchema = table.schema(); + + List expected = RandomData.generateList(tableSchema, 100, 1L); + + try (FileAppender writer = + Avro.write(localOutput(avroFile)).schema(tableSchema).build()) { + writer.addAll(expected); + } + + DataFile file = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(100) + .withFileSizeInBytes(avroFile.length()) + .withPath(avroFile.toString()) + .build(); + + table.newAppend().appendFile(file).commit(); + + Dataset df = spark.read().format("iceberg").load(location.toString()); + + List rows = df.collectAsList(); + assertThat(rows).as("Should contain 100 rows").hasSize(100); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(tableSchema.asStruct(), expected.get(i), rows.get(i)); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestBaseReader.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestBaseReader.java new file mode 100644 index 000000000000..a6d7d4827c0d --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestBaseReader.java @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.FileFormat.PARQUET; +import static org.apache.iceberg.Files.localOutput; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.BaseCombinedScanTask; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.types.Types; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestBaseReader { + + @TempDir private Path temp; + + private Table table; + + // Simulates the closeable iterator of data to be read + private static class CloseableIntegerRange implements CloseableIterator { + boolean closed; + Iterator iter; + + CloseableIntegerRange(long range) { + this.closed = false; + this.iter = IntStream.range(0, (int) range).iterator(); + } + + @Override + public void close() { + this.closed = true; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Integer next() { + return iter.next(); + } + } + + // Main reader class to test base class iteration logic. + // Keeps track of iterator closure. 
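+  // Iterators are keyed by each task's file location, so tests can assert whether the iterator
+  // for a given task was ever opened and whether it has been closed.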
+  private static class ClosureTrackingReader extends BaseReader<Integer, FileScanTask> {
+    private final Map<String, CloseableIntegerRange> tracker = Maps.newHashMap();
+
+    ClosureTrackingReader(Table table, List<FileScanTask> tasks) {
+      super(table, new BaseCombinedScanTask(tasks), null, null, false);
+    }
+
+    @Override
+    protected Stream<ContentFile<?>> referencedFiles(FileScanTask task) {
+      return Stream.of();
+    }
+
+    @Override
+    protected CloseableIterator<Integer> open(FileScanTask task) {
+      CloseableIntegerRange intRange = new CloseableIntegerRange(task.file().recordCount());
+      tracker.put(getKey(task), intRange);
+      return intRange;
+    }
+
+    public Boolean isIteratorClosed(FileScanTask task) {
+      return tracker.get(getKey(task)).closed;
+    }
+
+    public Boolean hasIterator(FileScanTask task) {
+      return tracker.containsKey(getKey(task));
+    }
+
+    private String getKey(FileScanTask task) {
+      return task.file().location();
+    }
+  }
+
+  @Test
+  public void testClosureOnDataExhaustion() throws IOException {
+    Integer totalTasks = 10;
+    Integer recordPerTask = 10;
+    List<FileScanTask> tasks = createFileScanTasks(totalTasks, recordPerTask);
+
+    ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks);
+
+    int countRecords = 0;
+    while (reader.next()) {
+      countRecords += 1;
+      assertThat(reader.get()).as("Reader should return non-null value").isNotNull();
+    }
+
+    assertThat(totalTasks * recordPerTask)
+        .as("Reader returned incorrect number of records")
+        .isEqualTo(countRecords);
+    tasks.forEach(
+        t ->
+            assertThat(reader.isIteratorClosed(t))
+                .as("All iterators should be closed after read exhaustion")
+                .isTrue());
+  }
+
+  @Test
+  public void testClosureDuringIteration() throws IOException {
+    Integer totalTasks = 2;
+    Integer recordPerTask = 1;
+    List<FileScanTask> tasks = createFileScanTasks(totalTasks, recordPerTask);
+    assertThat(tasks).hasSize(2);
+    FileScanTask firstTask = tasks.get(0);
+    FileScanTask secondTask = tasks.get(1);
+
+    ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks);
+
+    // Total of 2 elements
+    assertThat(reader.next()).isTrue();
+    assertThat(reader.isIteratorClosed(firstTask))
+        .as("First iter should not be closed on its last element")
+        .isFalse();
+
+    assertThat(reader.next()).isTrue();
+    assertThat(reader.isIteratorClosed(firstTask))
+        .as("First iter should be closed after moving to second iter")
+        .isTrue();
+    assertThat(reader.isIteratorClosed(secondTask))
+        .as("Second iter should not be closed on its last element")
+        .isFalse();
+
+    assertThat(reader.next()).isFalse();
+    assertThat(reader.isIteratorClosed(firstTask)).isTrue();
+    assertThat(reader.isIteratorClosed(secondTask)).isTrue();
+  }
+
+  @Test
+  public void testClosureWithoutAnyRead() throws IOException {
+    Integer totalTasks = 10;
+    Integer recordPerTask = 10;
+    List<FileScanTask> tasks = createFileScanTasks(totalTasks, recordPerTask);
+
+    ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks);
+
+    reader.close();
+
+    tasks.forEach(
+        t ->
+            assertThat(reader.hasIterator(t))
+                .as("Iterator should not be created eagerly for tasks")
+                .isFalse());
+  }
+
+  @Test
+  public void testExplicitClosure() throws IOException {
+    Integer totalTasks = 10;
+    Integer recordPerTask = 10;
+    List<FileScanTask> tasks = createFileScanTasks(totalTasks, recordPerTask);
+
+    ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks);
+
+    Integer halfDataSize = (totalTasks * recordPerTask) / 2;
+    for (int i = 0; i < halfDataSize; i++) {
+      assertThat(reader.next()).as("Reader should have some element").isTrue();
+      assertThat(reader.get()).as("Reader should return non-null value").isNotNull();
+    }
+
+    reader.close();
+
+    // Some tasks might not have been opened yet, so there is no corresponding tracker for them.
+    // But all iterators that have been created must be closed.
+    tasks.forEach(
+        t -> {
+          if (reader.hasIterator(t)) {
+            assertThat(reader.isIteratorClosed(t))
+                .as("Iterator should be closed after read exhaustion")
+                .isTrue();
+          }
+        });
+  }
+
+  @Test
+  public void testIdempotentExplicitClosure() throws IOException {
+    Integer totalTasks = 10;
+    Integer recordPerTask = 10;
+    List<FileScanTask> tasks = createFileScanTasks(totalTasks, recordPerTask);
+
+    ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks);
+
+    // Total 100 elements, only 5 iterators have been created
+    for (int i = 0; i < 45; i++) {
+      assertThat(reader.next()).as("Reader should have some element").isTrue();
+      assertThat(reader.get()).as("Reader should return non-null value").isNotNull();
+    }
+
+    for (int closeAttempt = 0; closeAttempt < 5; closeAttempt++) {
+      reader.close();
+      for (int i = 0; i < 5; i++) {
+        assertThat(reader.isIteratorClosed(tasks.get(i)))
+            .as("Iterator should be closed after read exhaustion")
+            .isTrue();
+      }
+      for (int i = 5; i < 10; i++) {
+        assertThat(reader.hasIterator(tasks.get(i)))
+            .as("Iterator should not be created eagerly for tasks")
+            .isFalse();
+      }
+    }
+  }
+
+  private List<FileScanTask> createFileScanTasks(Integer totalTasks, Integer recordPerTask)
+      throws IOException {
+    String desc = "make_scan_tasks";
+    File parent = temp.resolve(desc).toFile();
+    File location = new File(parent, "test");
+    File dataFolder = new File(location, "data");
+    assertThat(dataFolder.mkdirs()).as("mkdirs should succeed").isTrue();
+
+    Schema schema = new Schema(Types.NestedField.required(0, "id", Types.LongType.get()));
+
+    try {
+      this.table = TestTables.create(location, desc, schema, PartitionSpec.unpartitioned());
+      // Important: use the table's schema for the rest of the test
+      // When tables are created, the column ids are reassigned.
+      Schema tableSchema = table.schema();
+      List<GenericData.Record> expected = RandomData.generateList(tableSchema, recordPerTask, 1L);
+
+      AppendFiles appendFiles = table.newAppend();
+      for (int i = 0; i < totalTasks; i++) {
+        File parquetFile = new File(dataFolder, PARQUET.addExtension(UUID.randomUUID().toString()));
+        try (FileAppender<GenericData.Record> writer =
+            Parquet.write(localOutput(parquetFile)).schema(tableSchema).build()) {
+          writer.addAll(expected);
+        }
+        DataFile file =
+            DataFiles.builder(PartitionSpec.unpartitioned())
+                .withFileSizeInBytes(parquetFile.length())
+                .withPath(parquetFile.toString())
+                .withRecordCount(recordPerTask)
+                .build();
+        appendFiles.appendFile(file);
+      }
+      appendFiles.commit();
+
+      return StreamSupport.stream(table.newScan().planFiles().spliterator(), false)
+          .collect(Collectors.toList());
+    } finally {
+      TestTables.clearTables();
+    }
+  }
+}
diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestChangelogReader.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestChangelogReader.java
new file mode 100644
index 000000000000..52d6ff8c9c8b
--- /dev/null
+++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestChangelogReader.java
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.iceberg.ChangelogOperation; +import org.apache.iceberg.ChangelogScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Files; +import org.apache.iceberg.IncrementalChangelogScan; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestChangelogReader extends TestBase { + private static final Schema SCHEMA = + new Schema( + required(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA).bucket("data", 16).build(); + private final List records1 = Lists.newArrayList(); + private final List records2 = Lists.newArrayList(); + + private Table table; + private DataFile dataFile1; + private DataFile dataFile2; + + @TempDir private Path temp; + + @BeforeEach + public void before() throws IOException { + table = catalog.createTable(TableIdentifier.of("default", "test"), SCHEMA, SPEC); + // create some data + GenericRecord record = GenericRecord.create(table.schema()); + records1.add(record.copy("id", 29, "data", "a")); + records1.add(record.copy("id", 43, "data", "b")); + records1.add(record.copy("id", 61, "data", "c")); + records1.add(record.copy("id", 89, "data", "d")); + + records2.add(record.copy("id", 100, "data", "e")); + records2.add(record.copy("id", 121, "data", "f")); + records2.add(record.copy("id", 122, "data", "g")); + + // write data to files + dataFile1 = writeDataFile(records1); + dataFile2 = writeDataFile(records2); + } + + @AfterEach + public void after() { + catalog.dropTable(TableIdentifier.of("default", "test")); + } + + @Test + public void testInsert() throws IOException { + table.newAppend().appendFile(dataFile1).commit(); + long snapshotId1 = table.currentSnapshot().snapshotId(); + + table.newAppend().appendFile(dataFile2).commit(); + long snapshotId2 = table.currentSnapshot().snapshotId(); + + CloseableIterable> taskGroups 
= newScan().planTasks(); + + List rows = Lists.newArrayList(); + + for (ScanTaskGroup taskGroup : taskGroups) { + ChangelogRowReader reader = + new ChangelogRowReader(table, taskGroup, table.schema(), table.schema(), false); + while (reader.next()) { + rows.add(reader.get().copy()); + } + reader.close(); + } + + rows.sort((r1, r2) -> r1.getInt(0) - r2.getInt(0)); + + List expectedRows = Lists.newArrayList(); + addExpectedRows(expectedRows, ChangelogOperation.INSERT, snapshotId1, 0, records1); + addExpectedRows(expectedRows, ChangelogOperation.INSERT, snapshotId2, 1, records2); + + assertEquals("Should have expected rows", expectedRows, internalRowsToJava(rows)); + } + + @Test + public void testDelete() throws IOException { + table.newAppend().appendFile(dataFile1).commit(); + long snapshotId1 = table.currentSnapshot().snapshotId(); + + table.newDelete().deleteFile(dataFile1).commit(); + long snapshotId2 = table.currentSnapshot().snapshotId(); + + CloseableIterable> taskGroups = + newScan().fromSnapshotExclusive(snapshotId1).planTasks(); + + List rows = Lists.newArrayList(); + + for (ScanTaskGroup taskGroup : taskGroups) { + ChangelogRowReader reader = + new ChangelogRowReader(table, taskGroup, table.schema(), table.schema(), false); + while (reader.next()) { + rows.add(reader.get().copy()); + } + reader.close(); + } + + rows.sort((r1, r2) -> r1.getInt(0) - r2.getInt(0)); + + List expectedRows = Lists.newArrayList(); + addExpectedRows(expectedRows, ChangelogOperation.DELETE, snapshotId2, 0, records1); + + assertEquals("Should have expected rows", expectedRows, internalRowsToJava(rows)); + } + + @Test + public void testDataFileRewrite() throws IOException { + table.newAppend().appendFile(dataFile1).commit(); + table.newAppend().appendFile(dataFile2).commit(); + long snapshotId2 = table.currentSnapshot().snapshotId(); + + table + .newRewrite() + .rewriteFiles(ImmutableSet.of(dataFile1), ImmutableSet.of(dataFile2)) + .commit(); + + // the rewrite operation should generate no Changelog rows + CloseableIterable> taskGroups = + newScan().fromSnapshotExclusive(snapshotId2).planTasks(); + + List rows = Lists.newArrayList(); + + for (ScanTaskGroup taskGroup : taskGroups) { + ChangelogRowReader reader = + new ChangelogRowReader(table, taskGroup, table.schema(), table.schema(), false); + while (reader.next()) { + rows.add(reader.get().copy()); + } + reader.close(); + } + + assertThat(rows).as("Should have no rows").hasSize(0); + } + + @Test + public void testMixDeleteAndInsert() throws IOException { + table.newAppend().appendFile(dataFile1).commit(); + long snapshotId1 = table.currentSnapshot().snapshotId(); + + table.newDelete().deleteFile(dataFile1).commit(); + long snapshotId2 = table.currentSnapshot().snapshotId(); + + table.newAppend().appendFile(dataFile2).commit(); + long snapshotId3 = table.currentSnapshot().snapshotId(); + + CloseableIterable> taskGroups = newScan().planTasks(); + + List rows = Lists.newArrayList(); + + for (ScanTaskGroup taskGroup : taskGroups) { + ChangelogRowReader reader = + new ChangelogRowReader(table, taskGroup, table.schema(), table.schema(), false); + while (reader.next()) { + rows.add(reader.get().copy()); + } + reader.close(); + } + + // order by the change ordinal + rows.sort( + (r1, r2) -> { + if (r1.getInt(3) != r2.getInt(3)) { + return r1.getInt(3) - r2.getInt(3); + } else { + return r1.getInt(0) - r2.getInt(0); + } + }); + + List expectedRows = Lists.newArrayList(); + addExpectedRows(expectedRows, ChangelogOperation.INSERT, snapshotId1, 0, records1); + 
addExpectedRows(expectedRows, ChangelogOperation.DELETE, snapshotId2, 1, records1); + addExpectedRows(expectedRows, ChangelogOperation.INSERT, snapshotId3, 2, records2); + + assertEquals("Should have expected rows", expectedRows, internalRowsToJava(rows)); + } + + private IncrementalChangelogScan newScan() { + return table.newIncrementalChangelogScan(); + } + + private List addExpectedRows( + List expectedRows, + ChangelogOperation operation, + long snapshotId, + int changeOrdinal, + List records) { + records.forEach( + r -> + expectedRows.add(row(r.get(0), r.get(1), operation.name(), changeOrdinal, snapshotId))); + return expectedRows; + } + + protected List internalRowsToJava(List rows) { + return rows.stream().map(this::toJava).collect(Collectors.toList()); + } + + private Object[] toJava(InternalRow row) { + Object[] values = new Object[row.numFields()]; + values[0] = row.getInt(0); + values[1] = row.getString(1); + values[2] = row.getString(2); + values[3] = row.getInt(3); + values[4] = row.getLong(4); + return values; + } + + private DataFile writeDataFile(List records) throws IOException { + // records all use IDs that are in bucket id_bucket=0 + return FileHelpers.writeDataFile( + table, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + records); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestCompressionSettings.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestCompressionSettings.java new file mode 100644 index 000000000000..f411920a5dcc --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestCompressionSettings.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.FileFormat.AVRO; +import static org.apache.iceberg.FileFormat.ORC; +import static org.apache.iceberg.FileFormat.PARQUET; +import static org.apache.iceberg.RowLevelOperationMode.MERGE_ON_READ; +import static org.apache.iceberg.TableProperties.AVRO_COMPRESSION; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.DELETE_AVRO_COMPRESSION; +import static org.apache.iceberg.TableProperties.DELETE_DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.DELETE_MODE; +import static org.apache.iceberg.TableProperties.DELETE_ORC_COMPRESSION; +import static org.apache.iceberg.TableProperties.DELETE_PARQUET_COMPRESSION; +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; +import static org.apache.iceberg.TableProperties.ORC_COMPRESSION; +import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION; +import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_CODEC; +import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_LEVEL; +import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_STRATEGY; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.Map; +import org.apache.avro.file.DataFileConstants; +import org.apache.avro.file.DataFileReader; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.AvroFSInput; +import org.apache.hadoop.fs.FileContext; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestReader; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.SizeBasedFileRewriter; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.orc.OrcFile; +import org.apache.orc.Reader; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestCompressionSettings extends CatalogTestBase { + + private static final Configuration CONF = new Configuration(); + private static final 
String TABLE_NAME = "testWriteData"; + + private static SparkSession spark = null; + + @Parameter(index = 3) + private FileFormat format; + + @Parameter(index = 4) + private Map properties; + + @TempDir private java.nio.file.Path temp; + + @Parameters( + name = + "catalogName = {0}, implementation = {1}, config = {2}, format = {3}, properties = {4}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties(), + PARQUET, + ImmutableMap.of(COMPRESSION_CODEC, "zstd", COMPRESSION_LEVEL, "1") + }, + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties(), + PARQUET, + ImmutableMap.of(COMPRESSION_CODEC, "gzip") + }, + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties(), + ORC, + ImmutableMap.of(COMPRESSION_CODEC, "zstd", COMPRESSION_STRATEGY, "speed") + }, + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties(), + ORC, + ImmutableMap.of(COMPRESSION_CODEC, "zstd", COMPRESSION_STRATEGY, "compression") + }, + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties(), + AVRO, + ImmutableMap.of(COMPRESSION_CODEC, "snappy", COMPRESSION_LEVEL, "3") + } + }; + } + + @BeforeAll + public static void startSpark() { + TestCompressionSettings.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @BeforeEach + public void resetSpecificConfigurations() { + spark.conf().unset(COMPRESSION_CODEC); + spark.conf().unset(COMPRESSION_LEVEL); + spark.conf().unset(COMPRESSION_STRATEGY); + } + + @AfterEach + public void afterEach() { + spark.sql(String.format("DROP TABLE IF EXISTS %s", TABLE_NAME)); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestCompressionSettings.spark; + TestCompressionSettings.spark = null; + currentSpark.stop(); + } + + @TestTemplate + public void testWriteDataWithDifferentSetting() throws Exception { + sql("CREATE TABLE %s (id int, data string) USING iceberg", TABLE_NAME); + Map tableProperties = Maps.newHashMap(); + tableProperties.put(PARQUET_COMPRESSION, "gzip"); + tableProperties.put(AVRO_COMPRESSION, "gzip"); + tableProperties.put(ORC_COMPRESSION, "zlib"); + tableProperties.put(DELETE_PARQUET_COMPRESSION, "gzip"); + tableProperties.put(DELETE_AVRO_COMPRESSION, "gzip"); + tableProperties.put(DELETE_ORC_COMPRESSION, "zlib"); + tableProperties.put(DELETE_MODE, MERGE_ON_READ.modeName()); + tableProperties.put(FORMAT_VERSION, "2"); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')", TABLE_NAME, DEFAULT_FILE_FORMAT, format); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')", + TABLE_NAME, DELETE_DEFAULT_FILE_FORMAT, format); + for (Map.Entry entry : tableProperties.entrySet()) { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')", + TABLE_NAME, entry.getKey(), entry.getValue()); + } + + List expectedOrigin = Lists.newArrayList(); + for (int i = 0; i < 1000; i++) { + expectedOrigin.add(new SimpleRecord(i, "hello world" + i)); + } + + Dataset df = spark.createDataFrame(expectedOrigin, SimpleRecord.class); + + for (Map.Entry entry : properties.entrySet()) { + spark.conf().set(entry.getKey(), entry.getValue()); + } + + assertSparkConf(); + + df.select("id", "data") + .writeTo(TABLE_NAME) + 
.option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .append(); + Table table = catalog.loadTable(TableIdentifier.of("default", TABLE_NAME)); + List manifestFiles = table.currentSnapshot().dataManifests(table.io()); + try (ManifestReader reader = ManifestFiles.read(manifestFiles.get(0), table.io())) { + DataFile file = reader.iterator().next(); + InputFile inputFile = table.io().newInputFile(file.location()); + assertThat(getCompressionType(inputFile)) + .isEqualToIgnoringCase(properties.get(COMPRESSION_CODEC)); + } + + sql("DELETE from %s where id < 100", TABLE_NAME); + + table.refresh(); + List deleteManifestFiles = table.currentSnapshot().deleteManifests(table.io()); + Map specMap = Maps.newHashMap(); + specMap.put(0, PartitionSpec.unpartitioned()); + try (ManifestReader reader = + ManifestFiles.readDeleteManifest(deleteManifestFiles.get(0), table.io(), specMap)) { + DeleteFile file = reader.iterator().next(); + InputFile inputFile = table.io().newInputFile(file.location()); + assertThat(getCompressionType(inputFile)) + .isEqualToIgnoringCase(properties.get(COMPRESSION_CODEC)); + } + + SparkActions.get(spark) + .rewritePositionDeletes(table) + .option(SizeBasedFileRewriter.REWRITE_ALL, "true") + .execute(); + table.refresh(); + deleteManifestFiles = table.currentSnapshot().deleteManifests(table.io()); + try (ManifestReader reader = + ManifestFiles.readDeleteManifest(deleteManifestFiles.get(0), table.io(), specMap)) { + DeleteFile file = reader.iterator().next(); + InputFile inputFile = table.io().newInputFile(file.location()); + assertThat(getCompressionType(inputFile)) + .isEqualToIgnoringCase(properties.get(COMPRESSION_CODEC)); + } + } + + private String getCompressionType(InputFile inputFile) throws Exception { + switch (format) { + case ORC: + OrcFile.ReaderOptions readerOptions = OrcFile.readerOptions(CONF).useUTCTimestamp(true); + Reader orcReader = OrcFile.createReader(new Path(inputFile.location()), readerOptions); + return orcReader.getCompressionKind().name(); + case PARQUET: + ParquetMetadata footer = + ParquetFileReader.readFooter(CONF, new Path(inputFile.location()), NO_FILTER); + return footer.getBlocks().get(0).getColumns().get(0).getCodec().name(); + default: + FileContext fc = FileContext.getFileContext(CONF); + GenericDatumReader reader = new GenericDatumReader<>(); + DataFileReader fileReader = + (DataFileReader) + DataFileReader.openReader( + new AvroFSInput(fc, new Path(inputFile.location())), reader); + return fileReader.getMetaString(DataFileConstants.CODEC); + } + } + + private void assertSparkConf() { + String[] propertiesToCheck = {COMPRESSION_CODEC, COMPRESSION_LEVEL, COMPRESSION_STRATEGY}; + for (String prop : propertiesToCheck) { + String expected = properties.getOrDefault(prop, null); + String actual = spark.conf().get(prop, null); + assertThat(actual).isEqualToIgnoringCase(expected); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2.java new file mode 100644 index 000000000000..7404b18d14b2 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2.java @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
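// Editor's sketch (illustrative, not part of the patch): the parameterized cases above drive the codec
// through the Spark session conf, which the assertions expect to take precedence over the table-level
// *_COMPRESSION properties. Using the same statically imported keys and helpers as the test:
sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')", TABLE_NAME, PARQUET_COMPRESSION, "gzip");
spark.conf().set(COMPRESSION_CODEC, "zstd"); // session-level override, verified by assertSparkConf()
spark.conf().set(COMPRESSION_LEVEL, "1");    // codec-specific and optional
// a subsequent append then writes zstd-compressed Parquet, which getCompressionType(...) reads back
// from the file footer.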
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.math.BigDecimal; +import java.util.List; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestDataFrameWriterV2 extends TestBaseWithCatalog { + @BeforeEach + public void createTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testMergeSchemaFailsWithoutWriterOption() throws Exception { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", + tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA); + + Dataset twoColDF = + jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + twoColDF.writeTo(tableName).append(); + + assertEquals( + "Should have initial 2-column rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + + Dataset threeColDF = + jsonToDF( + "id bigint, data string, new_col float", + "{ \"id\": 3, \"data\": \"c\", \"new_col\": 12.06 }", + "{ \"id\": 4, \"data\": \"d\", \"new_col\": 14.41 }"); + + // this has a different error message than the case without accept-any-schema because it uses + // Iceberg checks + assertThatThrownBy(() -> threeColDF.writeTo(tableName).append()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Field new_col not found in source schema"); + } + + @TestTemplate + public void testMergeSchemaWithoutAcceptAnySchema() throws Exception { + Dataset twoColDF = + jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + twoColDF.writeTo(tableName).append(); + + assertEquals( + "Should have initial 2-column rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + + Dataset threeColDF = + jsonToDF( + "id bigint, data string, new_col float", + "{ \"id\": 3, \"data\": \"c\", 
\"new_col\": 12.06 }", + "{ \"id\": 4, \"data\": \"d\", \"new_col\": 14.41 }"); + + assertThatThrownBy(() -> threeColDF.writeTo(tableName).option("merge-schema", "true").append()) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "Cannot write to `testhadoop`.`default`.`table`, the reason is too many data columns"); + } + + @TestTemplate + public void testMergeSchemaSparkProperty() throws Exception { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", + tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA); + + Dataset twoColDF = + jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + twoColDF.writeTo(tableName).append(); + + assertEquals( + "Should have initial 2-column rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + + Dataset threeColDF = + jsonToDF( + "id bigint, data string, new_col float", + "{ \"id\": 3, \"data\": \"c\", \"new_col\": 12.06 }", + "{ \"id\": 4, \"data\": \"d\", \"new_col\": 14.41 }"); + + threeColDF.writeTo(tableName).option("mergeSchema", "true").append(); + + assertEquals( + "Should have 3-column rows", + ImmutableList.of( + row(1L, "a", null), row(2L, "b", null), row(3L, "c", 12.06F), row(4L, "d", 14.41F)), + sql("select * from %s order by id", tableName)); + } + + @TestTemplate + public void testMergeSchemaIcebergProperty() throws Exception { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", + tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA); + + Dataset twoColDF = + jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + twoColDF.writeTo(tableName).append(); + + assertEquals( + "Should have initial 2-column rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + + Dataset threeColDF = + jsonToDF( + "id bigint, data string, new_col float", + "{ \"id\": 3, \"data\": \"c\", \"new_col\": 12.06 }", + "{ \"id\": 4, \"data\": \"d\", \"new_col\": 14.41 }"); + + threeColDF.writeTo(tableName).option("merge-schema", "true").append(); + + assertEquals( + "Should have 3-column rows", + ImmutableList.of( + row(1L, "a", null), row(2L, "b", null), row(3L, "c", 12.06F), row(4L, "d", 14.41F)), + sql("select * from %s order by id", tableName)); + } + + @TestTemplate + public void testWriteWithCaseSensitiveOption() throws NoSuchTableException, ParseException { + SparkSession sparkSession = spark.cloneSession(); + sparkSession + .sql( + String.format( + "ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", + tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA)) + .collect(); + + String schema = "ID bigint, DaTa string"; + ImmutableList records = + ImmutableList.of("{ \"id\": 1, \"data\": \"a\" }", "{ \"id\": 2, \"data\": \"b\" }"); + + // disable spark.sql.caseSensitive + sparkSession.sql(String.format("SET %s=false", SQLConf.CASE_SENSITIVE().key())); + Dataset jsonDF = + sparkSession.createDataset(ImmutableList.copyOf(records), Encoders.STRING()); + Dataset ds = sparkSession.read().schema(schema).json(jsonDF); + // write should succeed + ds.writeTo(tableName).option("merge-schema", "true").option("check-ordering", "false").append(); + List fields = + Spark3Util.loadIcebergTable(sparkSession, tableName).schema().asStruct().fields(); + // Additional columns should not be created + assertThat(fields).hasSize(2); + + // enable spark.sql.caseSensitive + sparkSession.sql(String.format("SET %s=true", 
SQLConf.CASE_SENSITIVE().key())); + ds.writeTo(tableName).option("merge-schema", "true").option("check-ordering", "false").append(); + fields = Spark3Util.loadIcebergTable(sparkSession, tableName).schema().asStruct().fields(); + assertThat(fields).hasSize(4); + } + + @TestTemplate + public void testMergeSchemaSparkConfiguration() throws Exception { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", + tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA); + Dataset twoColDF = + jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + twoColDF.writeTo(tableName).append(); + + assertEquals( + "Should have initial 2-column rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + spark.conf().set("spark.sql.iceberg.merge-schema", "true"); + Dataset threeColDF = + jsonToDF( + "id bigint, data string, salary float", + "{ \"id\": 3, \"data\": \"c\", \"salary\": 120000.34 }", + "{ \"id\": 4, \"data\": \"d\", \"salary\": 140000.56 }"); + + threeColDF.writeTo(tableName).append(); + assertEquals( + "Should have 3-column rows", + ImmutableList.of( + row(1L, "a", null), + row(2L, "b", null), + row(3L, "c", 120000.34F), + row(4L, "d", 140000.56F)), + sql("select * from %s order by id", tableName)); + } + + @TestTemplate + public void testMergeSchemaIgnoreCastingLongToInt() throws Exception { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", + tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA); + + Dataset bigintDF = + jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + bigintDF.writeTo(tableName).append(); + + assertEquals( + "Should have initial rows with long column", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + + Dataset intDF = + jsonToDF( + "id int, data string", + "{ \"id\": 3, \"data\": \"c\" }", + "{ \"id\": 4, \"data\": \"d\" }"); + + // merge-schema=true on writes allows table schema updates when incoming data has schema changes + assertThatCode(() -> intDF.writeTo(tableName).option("merge-schema", "true").append()) + .doesNotThrowAnyException(); + + assertEquals( + "Should include new rows with unchanged long column type", + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d")), + sql("select * from %s order by id", tableName)); + + // verify the column type did not change + Types.NestedField idField = + Spark3Util.loadIcebergTable(spark, tableName).schema().findField("id"); + assertThat(idField.type().typeId()).isEqualTo(Type.TypeID.LONG); + } + + @TestTemplate + public void testMergeSchemaIgnoreCastingDoubleToFloat() throws Exception { + removeTables(); + sql("CREATE TABLE %s (id double, data string) USING iceberg", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", + tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA); + + Dataset doubleDF = + jsonToDF( + "id double, data string", + "{ \"id\": 1.0, \"data\": \"a\" }", + "{ \"id\": 2.0, \"data\": \"b\" }"); + + doubleDF.writeTo(tableName).append(); + + assertEquals( + "Should have initial rows with double column", + ImmutableList.of(row(1.0, "a"), row(2.0, "b")), + sql("select * from %s order by id", tableName)); + + Dataset floatDF = + jsonToDF( + "id float, data string", + "{ \"id\": 3.0, \"data\": \"c\" }", + "{ \"id\": 4.0, \"data\": \"d\" }"); + + // merge-schema=true on writes allows table schema updates when incoming data has schema 
changes + assertThatCode(() -> floatDF.writeTo(tableName).option("merge-schema", "true").append()) + .doesNotThrowAnyException(); + + assertEquals( + "Should include new rows with unchanged double column type", + ImmutableList.of(row(1.0, "a"), row(2.0, "b"), row(3.0, "c"), row(4.0, "d")), + sql("select * from %s order by id", tableName)); + + // verify the column type did not change + Types.NestedField idField = + Spark3Util.loadIcebergTable(spark, tableName).schema().findField("id"); + assertThat(idField.type().typeId()).isEqualTo(Type.TypeID.DOUBLE); + } + + @TestTemplate + public void testMergeSchemaIgnoreCastingDecimalToDecimalWithNarrowerPrecision() throws Exception { + removeTables(); + sql("CREATE TABLE %s (id decimal(6,2), data string) USING iceberg", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", + tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA); + + Dataset decimalPrecision6DF = + jsonToDF( + "id decimal(6,2), data string", + "{ \"id\": 1.0, \"data\": \"a\" }", + "{ \"id\": 2.0, \"data\": \"b\" }"); + + decimalPrecision6DF.writeTo(tableName).append(); + + assertEquals( + "Should have initial rows with decimal column with precision 6", + ImmutableList.of(row(new BigDecimal("1.00"), "a"), row(new BigDecimal("2.00"), "b")), + sql("select * from %s order by id", tableName)); + + Dataset decimalPrecision4DF = + jsonToDF( + "id decimal(4,2), data string", + "{ \"id\": 3.0, \"data\": \"c\" }", + "{ \"id\": 4.0, \"data\": \"d\" }"); + + // merge-schema=true on writes allows table schema updates when incoming data has schema changes + assertThatCode( + () -> decimalPrecision4DF.writeTo(tableName).option("merge-schema", "true").append()) + .doesNotThrowAnyException(); + + assertEquals( + "Should include new rows with unchanged decimal precision", + ImmutableList.of( + row(new BigDecimal("1.00"), "a"), + row(new BigDecimal("2.00"), "b"), + row(new BigDecimal("3.00"), "c"), + row(new BigDecimal("4.00"), "d")), + sql("select * from %s order by id", tableName)); + + // verify the decimal column precision did not change + Type idFieldType = + Spark3Util.loadIcebergTable(spark, tableName).schema().findField("id").type(); + assertThat(idFieldType.typeId()).isEqualTo(Type.TypeID.DECIMAL); + Types.DecimalType decimalType = (Types.DecimalType) idFieldType; + assertThat(decimalType.precision()).isEqualTo(6); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java new file mode 100644 index 000000000000..f51a06853a69 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
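// Editor's sketch (illustrative, not part of the patch): the tests above enable schema merging for
// DataFrameWriterV2 appends in three equivalent ways, all of which first require the table to accept
// any schema. Here `df` stands for any Dataset<Row> carrying extra columns, inside a test method
// declared `throws Exception` as above:
sql("ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')",
    tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA);
df.writeTo(tableName).option("merge-schema", "true").append();  // Iceberg option name
df.writeTo(tableName).option("mergeSchema", "true").append();   // Spark-style alias
spark.conf().set("spark.sql.iceberg.merge-schema", "true");     // session-wide switch
df.writeTo(tableName).append();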
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestDataFrameWriterV2Coercion extends TestBaseWithCatalog { + + @Parameters( + name = "catalogName = {0}, implementation = {1}, config = {2}, format = {3}, dataType = {4}") + public static Object[][] parameters() { + return new Object[][] { + parameter(FileFormat.AVRO, "byte"), + parameter(FileFormat.ORC, "byte"), + parameter(FileFormat.PARQUET, "byte"), + parameter(FileFormat.AVRO, "short"), + parameter(FileFormat.ORC, "short"), + parameter(FileFormat.PARQUET, "short") + }; + } + + private static Object[] parameter(FileFormat fileFormat, String dataType) { + return new Object[] { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + fileFormat, + dataType + }; + } + + @Parameter(index = 3) + private FileFormat format; + + @Parameter(index = 4) + private String dataType; + + @TestTemplate + public void testByteAndShortCoercion() { + + Dataset df = + jsonToDF( + "id " + dataType + ", data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + df.writeTo(tableName).option("write-format", format.name()).createOrReplace(); + + assertEquals( + "Should have initial 2-column rows", + ImmutableList.of(row(1, "a"), row(2, "b")), + sql("select * from %s order by id", tableName)); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWrites.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWrites.java new file mode 100644 index 000000000000..42552f385137 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWrites.java @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
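// Editor's sketch (illustrative, not part of the patch): the coercion test above writes byte and short
// columns and reads the values back as ints, since Spark byte/short columns are promoted to Iceberg's
// 32-bit integer type on table creation. Assuming Spark3Util as used in TestDataFrameWriterV2, inside
// a test declared `throws Exception`:
Types.NestedField idField =
    Spark3Util.loadIcebergTable(spark, tableName).schema().findField("id");
assertThat(idField.type()).isEqualTo(Types.IntegerType.get());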
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.spark.SparkSchemaUtil.convert; +import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsSafe; +import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Random; +import org.apache.avro.generic.GenericData.Record; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.Files; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.ParameterizedAvroDataTest; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkPlannedAvroReader; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkException; +import org.apache.spark.TaskContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.apache.spark.sql.DataFrameWriter; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestDataFrameWrites extends ParameterizedAvroDataTest { + private static final Configuration CONF = new Configuration(); + + @Parameters(name = "format = {0}") + public static Collection parameters() { + return Arrays.asList("parquet", "avro", "orc"); + } + + @Parameter private String format; + + @TempDir private File location; + + private static SparkSession spark = null; + private static JavaSparkContext sc = null; + + private Map tableProperties; + + private final org.apache.spark.sql.types.StructType sparkSchema = + new org.apache.spark.sql.types.StructType( + new org.apache.spark.sql.types.StructField[] { + new org.apache.spark.sql.types.StructField( + "optionalField", + org.apache.spark.sql.types.DataTypes.StringType, + true, + org.apache.spark.sql.types.Metadata.empty()), + new org.apache.spark.sql.types.StructField( + "requiredField", + org.apache.spark.sql.types.DataTypes.StringType, + 
false, + org.apache.spark.sql.types.Metadata.empty()) + }); + + private final Schema icebergSchema = + new Schema( + Types.NestedField.optional(1, "optionalField", Types.StringType.get()), + Types.NestedField.required(2, "requiredField", Types.StringType.get())); + + private final List data0 = + Arrays.asList( + "{\"optionalField\": \"a1\", \"requiredField\": \"bid_001\"}", + "{\"optionalField\": \"a2\", \"requiredField\": \"bid_002\"}"); + private final List data1 = + Arrays.asList( + "{\"optionalField\": \"d1\", \"requiredField\": \"bid_101\"}", + "{\"optionalField\": \"d2\", \"requiredField\": \"bid_102\"}", + "{\"optionalField\": \"d3\", \"requiredField\": \"bid_103\"}", + "{\"optionalField\": \"d4\", \"requiredField\": \"bid_104\"}"); + + @BeforeAll + public static void startSpark() { + TestDataFrameWrites.spark = SparkSession.builder().master("local[2]").getOrCreate(); + TestDataFrameWrites.sc = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestDataFrameWrites.spark; + TestDataFrameWrites.spark = null; + TestDataFrameWrites.sc = null; + currentSpark.stop(); + } + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + Table table = createTable(schema); + writeAndValidateWithLocations(table, new File(location, "data")); + } + + @TestTemplate + public void testWriteWithCustomDataLocation() throws IOException { + File tablePropertyDataLocation = temp.resolve("test-table-property-data-dir").toFile(); + Table table = createTable(new Schema(SUPPORTED_PRIMITIVES.fields())); + table + .updateProperties() + .set(TableProperties.WRITE_DATA_LOCATION, tablePropertyDataLocation.getAbsolutePath()) + .commit(); + writeAndValidateWithLocations(table, tablePropertyDataLocation); + } + + private Table createTable(Schema schema) { + HadoopTables tables = new HadoopTables(CONF); + return tables.create(schema, PartitionSpec.unpartitioned(), location.toString()); + } + + private void writeAndValidateWithLocations(Table table, File expectedDataDir) throws IOException { + Schema tableSchema = table.schema(); // use the table schema because ids are reassigned + + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + Iterable expected = RandomData.generate(tableSchema, 100, 0L); + writeData(expected, tableSchema); + + table.refresh(); + + List actual = readTable(); + + Iterator expectedIter = expected.iterator(); + Iterator actualIter = actual.iterator(); + while (expectedIter.hasNext() && actualIter.hasNext()) { + assertEqualsSafe(tableSchema.asStruct(), expectedIter.next(), actualIter.next()); + } + assertThat(actualIter.hasNext()) + .as("Both iterators should be exhausted") + .isEqualTo(expectedIter.hasNext()); + + table + .currentSnapshot() + .addedDataFiles(table.io()) + .forEach( + dataFile -> + assertThat(URI.create(dataFile.location()).getPath()) + .as( + String.format( + "File should have the parent directory %s, but has: %s.", + expectedDataDir.getAbsolutePath(), dataFile.location())) + .startsWith(expectedDataDir.getAbsolutePath())); + } + + private List readTable() { + Dataset result = spark.read().format("iceberg").load(location.toString()); + + return result.collectAsList(); + } + + private void writeData(Iterable records, Schema schema) throws IOException { + Dataset df = createDataset(records, schema); + DataFrameWriter writer = df.write().format("iceberg").mode("append"); + writer.save(location.toString()); + } + + private void 
writeDataWithFailOnPartition(Iterable records, Schema schema) + throws IOException, SparkException { + final int numPartitions = 10; + final int partitionToFail = new Random().nextInt(numPartitions); + MapPartitionsFunction failOnFirstPartitionFunc = + input -> { + int partitionId = TaskContext.getPartitionId(); + + if (partitionId == partitionToFail) { + throw new SparkException( + String.format("Intended exception in partition %d !", partitionId)); + } + return input; + }; + + Dataset df = + createDataset(records, schema) + .repartition(numPartitions) + .mapPartitions(failOnFirstPartitionFunc, Encoders.row(convert(schema))); + // This trick is needed because Spark 3 handles decimal overflow in RowEncoder which "changes" + // nullability of the column to "true" regardless of original nullability. + // Setting "check-nullability" option to "false" doesn't help as it fails at Spark analyzer. + Dataset convertedDf = df.sqlContext().createDataFrame(df.rdd(), convert(schema)); + DataFrameWriter writer = convertedDf.write().format("iceberg").mode("append"); + writer.save(location.toString()); + } + + private Dataset createDataset(Iterable records, Schema schema) throws IOException { + // this uses the SparkAvroReader to create a DataFrame from the list of records + // it assumes that SparkAvroReader is correct + File testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should succeed").isTrue(); + + try (FileAppender writer = + Avro.write(Files.localOutput(testFile)).schema(schema).named("test").build()) { + for (Record rec : records) { + writer.add(rec); + } + } + + // make sure the dataframe matches the records before moving on + List rows = Lists.newArrayList(); + try (AvroIterable reader = + Avro.read(Files.localInput(testFile)) + .createResolvingReader(SparkPlannedAvroReader::create) + .project(schema) + .build()) { + + Iterator recordIter = records.iterator(); + Iterator readIter = reader.iterator(); + while (recordIter.hasNext() && readIter.hasNext()) { + InternalRow row = readIter.next(); + assertEqualsUnsafe(schema.asStruct(), recordIter.next(), row); + rows.add(row); + } + assertThat(readIter.hasNext()) + .as("Both iterators should be exhausted") + .isEqualTo(recordIter.hasNext()); + } + + JavaRDD rdd = sc.parallelize(rows); + return spark.internalCreateDataFrame(JavaRDD.toRDD(rdd), convert(schema), false); + } + + @TestTemplate + public void testNullableWithWriteOption() throws IOException { + assumeThat(spark.version()) + .as("Spark 3 rejects writing nulls to a required column") + .startsWith("2"); + + String sourcePath = String.format("%s/nullable_poc/sourceFolder/", location); + String targetPath = String.format("%s/nullable_poc/targetFolder/", location); + + tableProperties = ImmutableMap.of(TableProperties.WRITE_DATA_LOCATION, targetPath); + + // read this and append to iceberg dataset + spark + .read() + .schema(sparkSchema) + .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data1)) + .write() + .parquet(sourcePath); + + // this is our iceberg dataset to which we will append data + new HadoopTables(spark.sessionState().newHadoopConf()) + .create( + icebergSchema, + PartitionSpec.builderFor(icebergSchema).identity("requiredField").build(), + tableProperties, + targetPath); + + // this is the initial data inside the iceberg dataset + spark + .read() + .schema(sparkSchema) + .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data0)) + .write() + .format("iceberg") + 
.mode(SaveMode.Append) + .save(targetPath); + + // read from parquet and append to iceberg w/ nullability check disabled + spark + .read() + .schema(SparkSchemaUtil.convert(icebergSchema)) + .parquet(sourcePath) + .write() + .format("iceberg") + .option(SparkWriteOptions.CHECK_NULLABILITY, false) + .mode(SaveMode.Append) + .save(targetPath); + + // read all data + List rows = spark.read().format("iceberg").load(targetPath).collectAsList(); + assumeThat(rows).as("Should contain 6 rows").hasSize(6); + } + + @TestTemplate + public void testNullableWithSparkSqlOption() throws IOException { + assumeThat(spark.version()) + .as("Spark 3 rejects writing nulls to a required column") + .startsWith("2"); + + String sourcePath = String.format("%s/nullable_poc/sourceFolder/", location); + String targetPath = String.format("%s/nullable_poc/targetFolder/", location); + + tableProperties = ImmutableMap.of(TableProperties.WRITE_DATA_LOCATION, targetPath); + + // read this and append to iceberg dataset + spark + .read() + .schema(sparkSchema) + .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data1)) + .write() + .parquet(sourcePath); + + SparkSession newSparkSession = + SparkSession.builder() + .master("local[2]") + .appName("NullableTest") + .config(SparkSQLProperties.CHECK_NULLABILITY, false) + .getOrCreate(); + + // this is our iceberg dataset to which we will append data + new HadoopTables(newSparkSession.sessionState().newHadoopConf()) + .create( + icebergSchema, + PartitionSpec.builderFor(icebergSchema).identity("requiredField").build(), + tableProperties, + targetPath); + + // this is the initial data inside the iceberg dataset + newSparkSession + .read() + .schema(sparkSchema) + .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data0)) + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(targetPath); + + // read from parquet and append to iceberg + newSparkSession + .read() + .schema(SparkSchemaUtil.convert(icebergSchema)) + .parquet(sourcePath) + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(targetPath); + + // read all data + List rows = newSparkSession.read().format("iceberg").load(targetPath).collectAsList(); + assumeThat(rows).as("Should contain 6 rows").hasSize(6); + } + + @TestTemplate + public void testFaultToleranceOnWrite() throws IOException { + Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields()); + Table table = createTable(schema); + + Iterable records = RandomData.generate(schema, 100, 0L); + writeData(records, schema); + + table.refresh(); + + Snapshot snapshotBeforeFailingWrite = table.currentSnapshot(); + List resultBeforeFailingWrite = readTable(); + + Iterable records2 = RandomData.generate(schema, 100, 0L); + + assertThatThrownBy(() -> writeDataWithFailOnPartition(records2, schema)) + .isInstanceOf(SparkException.class); + + table.refresh(); + + Snapshot snapshotAfterFailingWrite = table.currentSnapshot(); + List resultAfterFailingWrite = readTable(); + + assertThat(snapshotBeforeFailingWrite).isEqualTo(snapshotAfterFailingWrite); + assertThat(resultBeforeFailingWrite).isEqualTo(resultAfterFailingWrite); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java new file mode 100644 index 000000000000..c4ba96e63403 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java @@ -0,0 +1,521 @@ +/* + * 
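// Editor's sketch (illustrative, not part of the patch): the two nullability tests above relax the
// required-column check at two scopes. Per write, via the write option already used above (`df` and
// `targetPath` as in those tests):
df.write()
    .format("iceberg")
    .option(SparkWriteOptions.CHECK_NULLABILITY, false)
    .mode(SaveMode.Append)
    .save(targetPath);
// or per session, configured when the SparkSession is built:
SparkSession relaxedSession =
    SparkSession.builder()
        .master("local[2]")
        .config(SparkSQLProperties.CHECK_NULLABILITY, false)
        .getOrCreate();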
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.IOException; +import java.math.RoundingMode; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.math.LongMath; +import org.apache.iceberg.spark.CommitMetadata; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.functions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.io.TempDir; + +public class TestDataSourceOptions extends TestBaseWithCatalog { + + private static final Configuration CONF = new Configuration(); + private static final Schema SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + private static SparkSession spark = null; + + @TempDir private Path temp; + + @BeforeAll + public static void startSpark() { + TestDataSourceOptions.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestDataSourceOptions.spark; + TestDataSourceOptions.spark = null; + currentSpark.stop(); + } + + @TestTemplate + public void 
testWriteFormatOptionOverridesTableProperties() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.DEFAULT_FILE_FORMAT, "avro"); + Table table = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + try (CloseableIterable tasks = table.newScan().planFiles()) { + tasks.forEach( + task -> { + FileFormat fileFormat = FileFormat.fromFileName(task.file().location()); + assertThat(fileFormat).isEqualTo(FileFormat.PARQUET); + }); + } + } + + @TestTemplate + public void testNoWriteFormatOption() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.DEFAULT_FILE_FORMAT, "avro"); + Table table = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + try (CloseableIterable tasks = table.newScan().planFiles()) { + tasks.forEach( + task -> { + FileFormat fileFormat = FileFormat.fromFileName(task.file().location()); + assertThat(fileFormat).isEqualTo(FileFormat.AVRO); + }); + } + } + + @TestTemplate + public void testHadoopOptions() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + Configuration sparkHadoopConf = spark.sessionState().newHadoopConf(); + String originalDefaultFS = sparkHadoopConf.get("fs.default.name"); + + try { + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + tables.create(SCHEMA, spec, options, tableLocation); + + // set an invalid value for 'fs.default.name' in Spark Hadoop config + // to verify that 'hadoop.' 
data source options are propagated correctly + sparkHadoopConf.set("fs.default.name", "hdfs://localhost:9000"); + + List expectedRecords = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .option("hadoop.fs.default.name", "file:///") + .save(tableLocation); + + Dataset resultDf = + spark + .read() + .format("iceberg") + .option("hadoop.fs.default.name", "file:///") + .load(tableLocation); + List resultRecords = + resultDf.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + + assertThat(resultRecords).as("Records should match").isEqualTo(expectedRecords); + } finally { + sparkHadoopConf.set("fs.default.name", originalDefaultFS); + } + } + + @TestTemplate + public void testSplitOptionsOverridesTableProperties() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SPLIT_SIZE, String.valueOf(128L * 1024 * 1024)); // 128Mb + options.put( + TableProperties.DEFAULT_FILE_FORMAT, + String.valueOf(FileFormat.AVRO)); // Arbitrarily splittable + Table icebergTable = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf + .select("id", "data") + .repartition(1) + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + List files = + Lists.newArrayList(icebergTable.currentSnapshot().addedDataFiles(icebergTable.io())); + assertThat(files).as("Should have written 1 file").hasSize(1); + + long fileSize = files.get(0).fileSizeInBytes(); + long splitSize = LongMath.divide(fileSize, 2, RoundingMode.CEILING); + + Dataset resultDf = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SPLIT_SIZE, String.valueOf(splitSize)) + .load(tableLocation); + + assertThat(resultDf.javaRDD().getNumPartitions()) + .as("Spark partitions should match") + .isEqualTo(2); + } + + @TestTemplate + public void testIncrementalScanOptions() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table table = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d")); + for (SimpleRecord record : expectedRecords) { + Dataset originalDf = + spark.createDataFrame(Lists.newArrayList(record), SimpleRecord.class); + originalDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + } + List snapshotIds = SnapshotUtil.currentAncestorIds(table); + + // start-snapshot-id and snapshot-id are both configured. 
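// Editor's sketch (illustrative, not part of the patch): a well-formed incremental read uses only the
// start/end snapshot options and scans the half-open range (start, end]; the assertions that follow
// cover the option combinations that are rejected. `startId` and `endId` are hypothetical snapshot ids:
Dataset<Row> appended =
    spark
        .read()
        .format("iceberg")
        .option("start-snapshot-id", Long.toString(startId)) // exclusive lower bound
        .option("end-snapshot-id", Long.toString(endId))     // inclusive upper bound, optional
        .load(tableLocation);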
+ assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .option("snapshot-id", snapshotIds.get(3).toString()) + .option("start-snapshot-id", snapshotIds.get(3).toString()) + .load(tableLocation) + .explain()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Cannot set start-snapshot-id and end-snapshot-id for incremental scans when either snapshot-id or as-of-timestamp is set"); + + // end-snapshot-id and as-of-timestamp are both configured. + assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .option( + SparkReadOptions.AS_OF_TIMESTAMP, + Long.toString(table.snapshot(snapshotIds.get(3)).timestampMillis())) + .option("end-snapshot-id", snapshotIds.get(2).toString()) + .load(tableLocation) + .explain()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Cannot set start-snapshot-id and end-snapshot-id for incremental scans when either snapshot-id or as-of-timestamp is set"); + + // only end-snapshot-id is configured. + assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .option("end-snapshot-id", snapshotIds.get(2).toString()) + .load(tableLocation) + .explain()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Cannot set only end-snapshot-id for incremental scans. Please, set start-snapshot-id too."); + + // test (1st snapshot, current snapshot] incremental scan. + Dataset unboundedIncrementalResult = + spark + .read() + .format("iceberg") + .option("start-snapshot-id", snapshotIds.get(3).toString()) + .load(tableLocation); + List result1 = + unboundedIncrementalResult + .orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + assertThat(result1).as("Records should match").isEqualTo(expectedRecords.subList(1, 4)); + assertThat(unboundedIncrementalResult.count()) + .as("Unprocessed count should match record count") + .isEqualTo(3); + + Row row1 = unboundedIncrementalResult.agg(functions.min("id"), functions.max("id")).head(); + assertThat(row1.getInt(0)).as("min value should match").isEqualTo(2); + assertThat(row1.getInt(1)).as("max value should match").isEqualTo(4); + + // test (2nd snapshot, 3rd snapshot] incremental scan. 
+ Dataset incrementalResult = + spark + .read() + .format("iceberg") + .option("start-snapshot-id", snapshotIds.get(2).toString()) + .option("end-snapshot-id", snapshotIds.get(1).toString()) + .load(tableLocation); + List result2 = + incrementalResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(result2).as("Records should match").isEqualTo(expectedRecords.subList(2, 3)); + assertThat(incrementalResult.count()) + .as("Unprocessed count should match record count") + .isEqualTo(1); + + Row row2 = incrementalResult.agg(functions.min("id"), functions.max("id")).head(); + assertThat(row2.getInt(0)).as("min value should match").isEqualTo(3); + assertThat(row2.getInt(1)).as("max value should match").isEqualTo(3); + } + + @TestTemplate + public void testMetadataSplitSizeOptionOverrideTableProperties() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table table = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + // produce 1st manifest + originalDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + // produce 2nd manifest + originalDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + List manifests = table.currentSnapshot().allManifests(table.io()); + + assertThat(manifests).as("Must be 2 manifests").hasSize(2); + + // set the target metadata split size so each manifest ends up in a separate split + table + .updateProperties() + .set(TableProperties.METADATA_SPLIT_SIZE, String.valueOf(manifests.get(0).length())) + .commit(); + + Dataset entriesDf = spark.read().format("iceberg").load(tableLocation + "#entries"); + assertThat(entriesDf.javaRDD().getNumPartitions()).as("Num partitions must match").isEqualTo(2); + + // override the table property using options + entriesDf = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SPLIT_SIZE, String.valueOf(128 * 1024 * 1024)) + .load(tableLocation + "#entries"); + assertThat(entriesDf.javaRDD().getNumPartitions()).as("Num partitions must match").isEqualTo(1); + } + + @TestTemplate + public void testDefaultMetadataSplitSize() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table icebergTable = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + int splitSize = (int) TableProperties.METADATA_SPLIT_SIZE_DEFAULT; // 32MB split size + + int expectedSplits = + ((int) + tables + .load(tableLocation + "#entries") + .currentSnapshot() + .allManifests(icebergTable.io()) + .get(0) + .length() + + splitSize + - 1) + / splitSize; + + Dataset metadataDf = spark.read().format("iceberg").load(tableLocation + "#entries"); + + int partitionNum = metadataDf.javaRDD().getNumPartitions(); + 
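// Editor's sketch (illustrative, not part of the patch): the metadata-split tests above read the
// entries metadata table by suffixing the table location, and tune planning either through the
// METADATA_SPLIT_SIZE table property or a per-read option, e.g. (64 MB chosen arbitrarily):
Dataset<Row> entries =
    spark
        .read()
        .format("iceberg")
        .option(SparkReadOptions.SPLIT_SIZE, String.valueOf(64L * 1024 * 1024))
        .load(tableLocation + "#entries");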
assertThat(partitionNum).as("Spark partitions should match").isEqualTo(expectedSplits); + } + + @TestTemplate + public void testExtraSnapshotMetadata() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + HadoopTables tables = new HadoopTables(CONF); + tables.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + + List expectedRecords = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .option(SparkWriteOptions.SNAPSHOT_PROPERTY_PREFIX + ".extra-key", "someValue") + .option(SparkWriteOptions.SNAPSHOT_PROPERTY_PREFIX + ".another-key", "anotherValue") + .save(tableLocation); + + Table table = tables.load(tableLocation); + + assertThat(table.currentSnapshot().summary()) + .containsEntry("extra-key", "someValue") + .containsEntry("another-key", "anotherValue"); + } + + @TestTemplate + public void testExtraSnapshotMetadataWithSQL() throws InterruptedException, IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + HadoopTables tables = new HadoopTables(CONF); + + Table table = + tables.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + + List expectedRecords = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + spark.read().format("iceberg").load(tableLocation).createOrReplaceTempView("target"); + Thread writerThread = + new Thread( + () -> { + Map properties = + ImmutableMap.of( + "writer-thread", + String.valueOf(Thread.currentThread().getName()), + SnapshotSummary.EXTRA_METADATA_PREFIX + "extra-key", + "someValue", + SnapshotSummary.EXTRA_METADATA_PREFIX + "another-key", + "anotherValue"); + CommitMetadata.withCommitProperties( + properties, + () -> { + spark.sql("INSERT INTO target VALUES (3, 'c'), (4, 'd')"); + return 0; + }, + RuntimeException.class); + }); + writerThread.setName("test-extra-commit-message-writer-thread"); + writerThread.start(); + writerThread.join(); + + List snapshots = Lists.newArrayList(table.snapshots()); + assertThat(snapshots).hasSize(2); + assertThat(snapshots.get(0).summary().get("writer-thread")).isNull(); + assertThat(snapshots.get(1).summary()) + .containsEntry("writer-thread", "test-extra-commit-message-writer-thread") + .containsEntry("extra-key", "someValue") + .containsEntry("another-key", "anotherValue"); + } + + @TestTemplate + public void testExtraSnapshotMetadataWithDelete() + throws InterruptedException, NoSuchTableException { + spark.sessionState().conf().setConfString("spark.sql.shuffle.partitions", "1"); + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf.repartition(5, new Column("data")).select("id", "data").writeTo(tableName).append(); + Thread writerThread = + new Thread( + () -> { + Map properties = + ImmutableMap.of( + "writer-thread", + String.valueOf(Thread.currentThread().getName()), + SnapshotSummary.EXTRA_METADATA_PREFIX + "extra-key", + "someValue", + 
SnapshotSummary.EXTRA_METADATA_PREFIX + "another-key", + "anotherValue"); + CommitMetadata.withCommitProperties( + properties, + () -> { + spark.sql("DELETE FROM " + tableName + " where id = 1"); + return 0; + }, + RuntimeException.class); + }); + writerThread.setName("test-extra-commit-message-delete-thread"); + writerThread.start(); + writerThread.join(); + + Table table = validationCatalog.loadTable(tableIdent); + List snapshots = Lists.newArrayList(table.snapshots()); + + assertThat(snapshots).hasSize(2); + assertThat(snapshots.get(0).summary().get("writer-thread")).isNull(); + assertThat(snapshots.get(1).summary()) + .containsEntry("writer-thread", "test-extra-commit-message-delete-thread") + .containsEntry("extra-key", "someValue") + .containsEntry("another-key", "anotherValue"); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java new file mode 100644 index 000000000000..348173596e46 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java @@ -0,0 +1,666 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.Files.localOutput; +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.time.OffsetDateTime; +import java.util.Arrays; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PlanningMode; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.GenericsHelpers; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.SupportsPushDownV2Filters; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.StringStartsWith; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestFilteredScan { + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "ts", Types.TimestampType.withZone()), + Types.NestedField.optional(3, "data", Types.StringType.get())); + + private static final PartitionSpec BUCKET_BY_ID = + PartitionSpec.builderFor(SCHEMA).bucket("id", 4).build(); + + private static final PartitionSpec PARTITION_BY_DAY = + PartitionSpec.builderFor(SCHEMA).day("ts").build(); + + private static final PartitionSpec 
PARTITION_BY_HOUR = + PartitionSpec.builderFor(SCHEMA).hour("ts").build(); + + private static final PartitionSpec PARTITION_BY_DATA = + PartitionSpec.builderFor(SCHEMA).identity("data").build(); + + private static final PartitionSpec PARTITION_BY_ID = + PartitionSpec.builderFor(SCHEMA).identity("id").build(); + + private static SparkSession spark = null; + + @BeforeAll + public static void startSpark() { + TestFilteredScan.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestFilteredScan.spark; + TestFilteredScan.spark = null; + currentSpark.stop(); + } + + @TempDir private Path temp; + + @Parameter(index = 0) + private String format; + + @Parameter(index = 1) + private boolean vectorized; + + @Parameter(index = 2) + private PlanningMode planningMode; + + @Parameters(name = "format = {0}, vectorized = {1}, planningMode = {2}") + public static Object[][] parameters() { + return new Object[][] { + {"parquet", false, LOCAL}, + {"parquet", true, DISTRIBUTED}, + {"avro", false, LOCAL}, + {"orc", false, DISTRIBUTED}, + {"orc", true, LOCAL} + }; + } + + private File parent = null; + private File unpartitioned = null; + private List records = null; + + @BeforeEach + public void writeUnpartitionedTable() throws IOException { + this.parent = temp.resolve("TestFilteredScan").toFile(); + this.unpartitioned = new File(parent, "unpartitioned"); + File dataFolder = new File(unpartitioned, "data"); + assertThat(dataFolder.mkdirs()).as("Mkdir should succeed").isTrue(); + + Table table = + TABLES.create( + SCHEMA, + PartitionSpec.unpartitioned(), + ImmutableMap.of( + TableProperties.DATA_PLANNING_MODE, + planningMode.modeName(), + TableProperties.DELETE_PLANNING_MODE, + planningMode.modeName()), + unpartitioned.toString()); + Schema tableSchema = table.schema(); // use the table schema because ids are reassigned + + FileFormat fileFormat = FileFormat.fromString(format); + + File testFile = new File(dataFolder, fileFormat.addExtension(UUID.randomUUID().toString())); + + this.records = testRecords(tableSchema); + + try (FileAppender writer = + new GenericAppenderFactory(tableSchema).newAppender(localOutput(testFile), fileFormat)) { + writer.addAll(records); + } + + DataFile file = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(records.size()) + .withFileSizeInBytes(testFile.length()) + .withPath(testFile.toString()) + .build(); + + table.newAppend().appendFile(file).commit(); + } + + @TestTemplate + public void testUnpartitionedIDFilters() { + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", unpartitioned.toString())); + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + for (int i = 0; i < 10; i += 1) { + pushFilters(builder, EqualTo.apply("id", i)); + Batch scan = builder.build().toBatch(); + + InputPartition[] partitions = scan.planInputPartitions(); + assertThat(partitions).as("Should only create one task for a small file").hasSize(1); + + // validate row filtering + assertEqualsSafe( + SCHEMA.asStruct(), expected(i), read(unpartitioned.toString(), vectorized, "id = " + i)); + } + } + + @TestTemplate + public void testUnpartitionedCaseInsensitiveIDFilters() { + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", unpartitioned.toString())); + + // set spark.sql.caseSensitive to false + String caseSensitivityBeforeTest = 
TestFilteredScan.spark.conf().get("spark.sql.caseSensitive"); + TestFilteredScan.spark.conf().set("spark.sql.caseSensitive", "false"); + + try { + + for (int i = 0; i < 10; i += 1) { + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options) + .caseSensitive(false); + + pushFilters( + builder, + EqualTo.apply("ID", i)); // note lower(ID) == lower(id), so there must be a match + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + assertThat(tasks).as("Should only create one task for a small file").hasSize(1); + + // validate row filtering + assertEqualsSafe( + SCHEMA.asStruct(), + expected(i), + read(unpartitioned.toString(), vectorized, "id = " + i)); + } + } finally { + // return global conf to previous state + TestFilteredScan.spark.conf().set("spark.sql.caseSensitive", caseSensitivityBeforeTest); + } + } + + @TestTemplate + public void testUnpartitionedTimestampFilter() { + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", unpartitioned.toString())); + + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, LessThan.apply("ts", "2017-12-22T00:00:00+00:00")); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + assertThat(tasks).as("Should only create one task for a small file").hasSize(1); + + assertEqualsSafe( + SCHEMA.asStruct(), + expected(5, 6, 7, 8, 9), + read( + unpartitioned.toString(), + vectorized, + "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); + } + + @TestTemplate + public void testBucketPartitionedIDFilters() { + Table table = buildPartitionedTable("bucketed_by_id", BUCKET_BY_ID); + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + + Batch unfiltered = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options).build().toBatch(); + assertThat(unfiltered.planInputPartitions()) + .as("Unfiltered table should created 4 read tasks") + .hasSize(4); + + for (int i = 0; i < 10; i += 1) { + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, EqualTo.apply("id", i)); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + + // validate predicate push-down + assertThat(tasks).as("Should only create one task for a single bucket").hasSize(1); + + // validate row filtering + assertEqualsSafe( + SCHEMA.asStruct(), expected(i), read(table.location(), vectorized, "id = " + i)); + } + } + + @SuppressWarnings("checkstyle:AvoidNestedBlocks") + @TestTemplate + public void testDayPartitionedTimestampFilters() { + Table table = buildPartitionedTable("partitioned_by_day", PARTITION_BY_DAY); + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + Batch unfiltered = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options).build().toBatch(); + + assertThat(unfiltered.planInputPartitions()) + .as("Unfiltered table should created 2 read tasks") + .hasSize(2); + + { + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, LessThan.apply("ts", "2017-12-22T00:00:00+00:00")); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + assertThat(tasks).as("Should create one 
task for 2017-12-21").hasSize(1); + + assertEqualsSafe( + SCHEMA.asStruct(), + expected(5, 6, 7, 8, 9), + read( + table.location(), vectorized, "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); + } + + { + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters( + builder, + And.apply( + GreaterThan.apply("ts", "2017-12-22T06:00:00+00:00"), + LessThan.apply("ts", "2017-12-22T08:00:00+00:00"))); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + assertThat(tasks).as("Should create one task for 2017-12-22").hasSize(1); + + assertEqualsSafe( + SCHEMA.asStruct(), + expected(1, 2), + read( + table.location(), + vectorized, + "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " + + "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)")); + } + } + + @SuppressWarnings("checkstyle:AvoidNestedBlocks") + @TestTemplate + public void testHourPartitionedTimestampFilters() { + Table table = buildPartitionedTable("partitioned_by_hour", PARTITION_BY_HOUR); + + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + Batch unfiltered = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options).build().toBatch(); + + assertThat(unfiltered.planInputPartitions()) + .as("Unfiltered table should created 9 read tasks") + .hasSize(9); + + { + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, LessThan.apply("ts", "2017-12-22T00:00:00+00:00")); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + assertThat(tasks).as("Should create 4 tasks for 2017-12-21: 15, 17, 21, 22").hasSize(4); + + assertEqualsSafe( + SCHEMA.asStruct(), + expected(8, 9, 7, 6, 5), + read( + table.location(), vectorized, "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); + } + + { + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters( + builder, + And.apply( + GreaterThan.apply("ts", "2017-12-22T06:00:00+00:00"), + LessThan.apply("ts", "2017-12-22T08:00:00+00:00"))); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + assertThat(tasks).as("Should create 2 tasks for 2017-12-22: 6, 7").hasSize(2); + + assertEqualsSafe( + SCHEMA.asStruct(), + expected(2, 1), + read( + table.location(), + vectorized, + "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " + + "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)")); + } + } + + @SuppressWarnings("checkstyle:AvoidNestedBlocks") + @TestTemplate + public void testFilterByNonProjectedColumn() { + { + Schema actualProjection = SCHEMA.select("id", "data"); + List expected = Lists.newArrayList(); + for (Record rec : expected(5, 6, 7, 8, 9)) { + expected.add(projectFlat(actualProjection, rec)); + } + + assertEqualsSafe( + actualProjection.asStruct(), + expected, + read( + unpartitioned.toString(), + vectorized, + "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)", + "id", + "data")); + } + + { + // only project id: ts will be projected because of the filter, but data will not be included + + Schema actualProjection = SCHEMA.select("id"); + List expected = Lists.newArrayList(); + for (Record rec : expected(1, 2)) { + expected.add(projectFlat(actualProjection, rec)); + } + + assertEqualsSafe( + actualProjection.asStruct(), + expected, + read( + unpartitioned.toString(), + 
vectorized, + "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " + + "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)", + "id")); + } + } + + @TestTemplate + public void testPartitionedByDataStartsWithFilter() { + Table table = buildPartitionedTable("partitioned_by_data", PARTITION_BY_DATA); + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, new StringStartsWith("data", "junc")); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions()).hasSize(1); + } + + @TestTemplate + public void testPartitionedByDataNotStartsWithFilter() { + Table table = buildPartitionedTable("partitioned_by_data", PARTITION_BY_DATA); + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, new Not(new StringStartsWith("data", "junc"))); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions()).hasSize(9); + } + + @TestTemplate + public void testPartitionedByIdStartsWith() { + Table table = buildPartitionedTable("partitioned_by_id", PARTITION_BY_ID); + + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, new StringStartsWith("data", "junc")); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions()).hasSize(1); + } + + @TestTemplate + public void testPartitionedByIdNotStartsWith() { + Table table = buildPartitionedTable("partitioned_by_id", PARTITION_BY_ID); + + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, new Not(new StringStartsWith("data", "junc"))); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions()).hasSize(9); + } + + @TestTemplate + public void testUnpartitionedStartsWith() { + Dataset df = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()); + + List matchedData = + df.select("data").where("data LIKE 'jun%'").as(Encoders.STRING()).collectAsList(); + + assertThat(matchedData).hasSize(1); + assertThat(matchedData.get(0)).isEqualTo("junction"); + } + + @TestTemplate + public void testUnpartitionedNotStartsWith() { + Dataset df = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()); + + List matchedData = + df.select("data").where("data NOT LIKE 'jun%'").as(Encoders.STRING()).collectAsList(); + + List expected = + testRecords(SCHEMA).stream() + .map(r -> r.getField("data").toString()) + .filter(d -> !d.startsWith("jun")) + .collect(Collectors.toList()); + + assertThat(matchedData).hasSize(9); + assertThat(Sets.newHashSet(matchedData)).isEqualTo(Sets.newHashSet(expected)); + } + + private static Record projectFlat(Schema projection, Record record) { + Record result = GenericRecord.create(projection); + List fields = projection.asStruct().fields(); + 
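+    // copy each field of the projected schema from the source record, matching by field name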
for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + result.set(i, record.getField(field.name())); + } + return result; + } + + public static void assertEqualsUnsafe( + Types.StructType struct, List expected, List actual) { + // TODO: match records by ID + int numRecords = Math.min(expected.size(), actual.size()); + for (int i = 0; i < numRecords; i += 1) { + GenericsHelpers.assertEqualsUnsafe(struct, expected.get(i), actual.get(i)); + } + assertThat(actual).as("Number of results should match expected").hasSameSizeAs(expected); + } + + public static void assertEqualsSafe( + Types.StructType struct, List expected, List actual) { + // TODO: match records by ID + int numRecords = Math.min(expected.size(), actual.size()); + for (int i = 0; i < numRecords; i += 1) { + GenericsHelpers.assertEqualsSafe(struct, expected.get(i), actual.get(i)); + } + assertThat(actual).as("Number of results should match expected").hasSameSizeAs(expected); + } + + private List expected(int... ordinals) { + List expected = Lists.newArrayListWithExpectedSize(ordinals.length); + for (int ord : ordinals) { + expected.add(records.get(ord)); + } + return expected; + } + + private void pushFilters(ScanBuilder scan, Filter... filters) { + assertThat(scan).isInstanceOf(SupportsPushDownV2Filters.class); + SupportsPushDownV2Filters filterable = (SupportsPushDownV2Filters) scan; + filterable.pushPredicates(Arrays.stream(filters).map(Filter::toV2).toArray(Predicate[]::new)); + } + + private Table buildPartitionedTable(String desc, PartitionSpec spec) { + File location = new File(parent, desc); + Table table = TABLES.create(SCHEMA, spec, location.toString()); + + // Do not combine or split files because the tests expect a split per partition. + // A target split size of 2048 helps us achieve that. + table.updateProperties().set("read.split.target-size", "2048").commit(); + + // copy the unpartitioned table into the partitioned table to produce the partitioned data + Dataset allRows = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()); + + // disable fanout writers to locally order records for future verifications + allRows + .write() + .option(SparkWriteOptions.FANOUT_ENABLED, "false") + .format("iceberg") + .mode("append") + .save(table.location()); + + table.refresh(); + + return table; + } + + private List testRecords(Schema schema) { + return Lists.newArrayList( + record(schema, 0L, parse("2017-12-22T09:20:44.294658+00:00"), "junction"), + record(schema, 1L, parse("2017-12-22T07:15:34.582910+00:00"), "alligator"), + record(schema, 2L, parse("2017-12-22T06:02:09.243857+00:00"), ""), + record(schema, 3L, parse("2017-12-22T03:10:11.134509+00:00"), "clapping"), + record(schema, 4L, parse("2017-12-22T00:34:00.184671+00:00"), "brush"), + record(schema, 5L, parse("2017-12-21T22:20:08.935889+00:00"), "trap"), + record(schema, 6L, parse("2017-12-21T21:55:30.589712+00:00"), "element"), + record(schema, 7L, parse("2017-12-21T17:31:14.532797+00:00"), "limited"), + record(schema, 8L, parse("2017-12-21T15:21:51.237521+00:00"), "global"), + record(schema, 9L, parse("2017-12-21T15:02:15.230570+00:00"), "goldfish")); + } + + private static List read(String table, boolean vectorized, String expr) { + return read(table, vectorized, expr, "*"); + } + + private static List read( + String table, boolean vectorized, String expr, String select0, String... 
selectN) { + Dataset dataset = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(table) + .filter(expr) + .select(select0, selectN); + return dataset.collectAsList(); + } + + private static OffsetDateTime parse(String timestamp) { + return OffsetDateTime.parse(timestamp); + } + + private static Record record(Schema schema, Object... values) { + Record rec = GenericRecord.create(schema); + for (int i = 0; i < values.length; i += 1) { + rec.set(i, values[i]); + } + return rec; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestForwardCompatibility.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestForwardCompatibility.java new file mode 100644 index 000000000000..f59c77179a9c --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestForwardCompatibility.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.Files.localInput; +import static org.apache.iceberg.Files.localOutput; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.TimeoutException; +import org.apache.avro.generic.GenericData; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestWriter; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.execution.streaming.MemoryStream; +import org.apache.spark.sql.streaming.StreamingQuery; +import org.apache.spark.sql.streaming.StreamingQueryException; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; 
+import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import scala.Option; +import scala.collection.JavaConverters; + +public class TestForwardCompatibility { + private static final Configuration CONF = new Configuration(); + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get())); + + // create a spec for the schema that uses a "zero" transform that produces all 0s + private static final PartitionSpec UNKNOWN_SPEC = + org.apache.iceberg.TestHelpers.newExpectedSpecBuilder() + .withSchema(SCHEMA) + .withSpecId(0) + .addField("zero", 1, "id_zero") + .build(); + // create a fake spec to use to write table metadata + private static final PartitionSpec FAKE_SPEC = + org.apache.iceberg.TestHelpers.newExpectedSpecBuilder() + .withSchema(SCHEMA) + .withSpecId(0) + .addField("identity", 1, "id_zero") + .build(); + + @TempDir private Path temp; + + private static SparkSession spark = null; + + @BeforeAll + public static void startSpark() { + TestForwardCompatibility.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestForwardCompatibility.spark; + TestForwardCompatibility.spark = null; + currentSpark.stop(); + } + + @Test + public void testSparkWriteFailsUnknownTransform() throws IOException { + File parent = temp.resolve("avro").toFile(); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + + HadoopTables tables = new HadoopTables(CONF); + tables.create(SCHEMA, UNKNOWN_SPEC, location.toString()); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + assertThatThrownBy( + () -> + df.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(location.toString())) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("Cannot write using unsupported transforms: zero"); + } + + @Test + public void testSparkStreamingWriteFailsUnknownTransform() throws IOException, TimeoutException { + File parent = temp.resolve("avro").toFile(); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + File checkpoint = new File(parent, "checkpoint"); + checkpoint.mkdirs(); + + HadoopTables tables = new HadoopTables(CONF); + tables.create(SCHEMA, UNKNOWN_SPEC, location.toString()); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + StreamingQuery query = + inputStream + .toDF() + .selectExpr("value AS id", "CAST (value AS STRING) AS data") + .writeStream() + .outputMode("append") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()) + .start(); + + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + + assertThatThrownBy(query::processAllAvailable) + .isInstanceOf(StreamingQueryException.class) + .hasMessageContaining("Cannot write using unsupported transforms: zero"); + } + + @Test + public void testSparkCanReadUnknownTransform() throws IOException { + File parent = temp.resolve("avro").toFile(); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + + HadoopTables tables = new HadoopTables(CONF); + 
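+    // create the table with the unknown "zero" transform; the read below verifies that Spark
+    // can still scan data written with a partition transform it does not recognize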
Table table = tables.create(SCHEMA, UNKNOWN_SPEC, location.toString()); + + // enable snapshot inheritance to avoid rewriting the manifest with an unknown transform + table.updateProperties().set(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, "true").commit(); + + List expected = RandomData.generateList(table.schema(), 100, 1L); + + File parquetFile = + new File(dataFolder, FileFormat.PARQUET.addExtension(UUID.randomUUID().toString())); + FileAppender writer = + Parquet.write(localOutput(parquetFile)).schema(table.schema()).build(); + try { + writer.addAll(expected); + } finally { + writer.close(); + } + + DataFile file = + DataFiles.builder(FAKE_SPEC) + .withInputFile(localInput(parquetFile)) + .withMetrics(writer.metrics()) + .withPartitionPath("id_zero=0") + .build(); + + OutputFile manifestFile = localOutput(FileFormat.AVRO.addExtension(temp.toFile().toString())); + ManifestWriter manifestWriter = ManifestFiles.write(FAKE_SPEC, manifestFile); + try { + manifestWriter.add(file); + } finally { + manifestWriter.close(); + } + + table.newFastAppend().appendManifest(manifestWriter.toManifestFile()).commit(); + + Dataset df = spark.read().format("iceberg").load(location.toString()); + + List rows = df.collectAsList(); + assertThat(rows).as("Should contain 100 rows").hasSize(100); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(table.schema().asStruct(), expected.get(i), rows.get(i)); + } + } + + private MemoryStream newMemoryStream(int id, SQLContext sqlContext, Encoder encoder) { + return new MemoryStream<>(id, sqlContext, Option.empty(), encoder); + } + + private void send(List records, MemoryStream stream) { + stream.addData(JavaConverters.asScalaBuffer(records)); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSource.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSource.java new file mode 100644 index 000000000000..a850275118db --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSource.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +public class TestIcebergSource extends IcebergSource { + @Override + public String shortName() { + return "iceberg-test"; + } + + @Override + public Identifier extractIdentifier(CaseInsensitiveStringMap options) { + TableIdentifier ti = TableIdentifier.parse(options.get("iceberg.table.name")); + return Identifier.of(ti.namespace().levels(), ti.name()); + } + + @Override + public String extractCatalog(CaseInsensitiveStringMap options) { + return SparkSession.active().sessionState().catalogManager().currentCatalog().name(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHadoopTables.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHadoopTables.java new file mode 100644 index 000000000000..35d6e119e86f --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHadoopTables.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.hadoop.HadoopTables; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.io.TempDir; + +public class TestIcebergSourceHadoopTables extends TestIcebergSourceTablesBase { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + @TempDir private File tableDir; + String tableLocation = null; + + @BeforeEach + public void setupTable() throws Exception { + this.tableLocation = tableDir.toURI().toString(); + } + + @Override + public Table createTable( + TableIdentifier ident, Schema schema, PartitionSpec spec, Map properties) { + return TABLES.create(schema, spec, properties, tableLocation); + } + + @Override + public void dropTable(TableIdentifier ident) { + TABLES.dropTable(tableLocation); + } + + @Override + public Table loadTable(TableIdentifier ident, String entriesSuffix) { + return TABLES.load(loadLocation(ident, entriesSuffix)); + } + + @Override + public String loadLocation(TableIdentifier ident, String entriesSuffix) { + return String.format("%s#%s", loadLocation(ident), entriesSuffix); + } + + @Override + public String loadLocation(TableIdentifier ident) { + return tableLocation; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHiveTables.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHiveTables.java new file mode 100644 index 000000000000..9120bbcc35a3 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHiveTables.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.Map; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; + +public class TestIcebergSourceHiveTables extends TestIcebergSourceTablesBase { + + private static TableIdentifier currentIdentifier; + + @BeforeAll + public static void start() { + Namespace db = Namespace.of("db"); + if (!catalog.namespaceExists(db)) { + catalog.createNamespace(db); + } + } + + @AfterEach + public void dropTable() throws IOException { + if (!catalog.tableExists(currentIdentifier)) { + return; + } + + dropTable(currentIdentifier); + } + + @Override + public Table createTable( + TableIdentifier ident, Schema schema, PartitionSpec spec, Map properties) { + TestIcebergSourceHiveTables.currentIdentifier = ident; + return TestIcebergSourceHiveTables.catalog.createTable(ident, schema, spec, properties); + } + + @Override + public void dropTable(TableIdentifier ident) throws IOException { + Table table = catalog.loadTable(ident); + Path tablePath = new Path(table.location()); + FileSystem fs = tablePath.getFileSystem(spark.sessionState().newHadoopConf()); + fs.delete(tablePath, true); + catalog.dropTable(ident, false); + } + + @Override + public Table loadTable(TableIdentifier ident, String entriesSuffix) { + TableIdentifier identifier = + TableIdentifier.of(ident.namespace().level(0), ident.name(), entriesSuffix); + return TestIcebergSourceHiveTables.catalog.loadTable(identifier); + } + + @Override + public String loadLocation(TableIdentifier ident, String entriesSuffix) { + return String.format("%s.%s", loadLocation(ident), entriesSuffix); + } + + @Override + public String loadLocation(TableIdentifier ident) { + return ident.toString(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java new file mode 100644 index 000000000000..29216150d362 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java @@ -0,0 +1,2336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.ManifestContent.DATA; +import static org.apache.iceberg.ManifestContent.DELETES; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Path; +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.StringJoiner; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecordBuilder; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.Files; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.actions.RewriteManifests; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.deletes.PositionDeleteWriter; +import org.apache.iceberg.encryption.EncryptedOutputFile; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.mapping.MappingUtil; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Pair; +import org.apache.spark.SparkException; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import 
org.junit.jupiter.api.io.TempDir; + +public abstract class TestIcebergSourceTablesBase extends TestBase { + + private static final Schema SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + + private static final Schema SCHEMA2 = + new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()), + optional(3, "category", Types.StringType.get())); + + private static final Schema SCHEMA3 = + new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(3, "category", Types.StringType.get())); + + private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA).identity("id").build(); + + @TempDir protected Path temp; + + public abstract Table createTable( + TableIdentifier ident, Schema schema, PartitionSpec spec, Map properties); + + public abstract Table loadTable(TableIdentifier ident, String entriesSuffix); + + public abstract String loadLocation(TableIdentifier ident, String entriesSuffix); + + public abstract String loadLocation(TableIdentifier ident); + + public abstract void dropTable(TableIdentifier ident) throws IOException; + + @AfterEach + public void removeTable() { + spark.sql("DROP TABLE IF EXISTS parquet_table"); + } + + private Table createTable(TableIdentifier ident, Schema schema, PartitionSpec spec) { + return createTable(ident, schema, spec, ImmutableMap.of()); + } + + @Test + public synchronized void testTablesSupport() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "1"), new SimpleRecord(2, "2"), new SimpleRecord(3, "3")); + + Dataset inputDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + Dataset resultDf = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + List actualRecords = + resultDf.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + + assertThat(actualRecords).as("Records should match").isEqualTo(expectedRecords); + } + + @Test + public void testEntriesTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table entriesTable = loadTable(tableIdentifier, "entries"); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + Dataset entriesTableDs = + spark.read().format("iceberg").load(loadLocation(tableIdentifier, "entries")); + List actual = TestHelpers.selectNonDerived(entriesTableDs).collectAsList(); + + Snapshot snapshot = table.currentSnapshot(); + + assertThat(snapshot.allManifests(table.io())).as("Should only contain one manifest").hasSize(1); + + InputFile manifest = table.io().newInputFile(snapshot.allManifests(table.io()).get(0).path()); + List expected = Lists.newArrayList(); + try (CloseableIterable rows = + Avro.read(manifest).project(entriesTable.schema()).build()) { + // each row must inherit snapshot_id and sequence_number + rows.forEach( + row -> { + row.put(2, 1L); // data sequence number + row.put(3, 1L); // file 
sequence number + GenericData.Record file = (GenericData.Record) row.get("data_file"); + TestHelpers.asMetadataRecord(file); + expected.add(row); + }); + } + + assertThat(expected).as("Entries table should have one row").hasSize(1); + assertThat(actual).as("Actual results should have one row").hasSize(1); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(entriesTableDs), expected.get(0), actual.get(0)); + } + + @Test + public void testEntriesTablePartitionedPrune() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .select("status") + .collectAsList(); + + assertThat(actual).as("Results should contain only one status").hasSize(1); + assertThat(actual.get(0).getInt(0)).as("That status should be Added (1)").isEqualTo(1); + } + + @Test + public void testEntriesTableDataFilePrune() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + DataFile file = table.currentSnapshot().addedDataFiles(table.io()).iterator().next(); + + List singleActual = + rowsToJava( + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .select("data_file.file_path") + .collectAsList()); + + List singleExpected = ImmutableList.of(row(file.path())); + + assertEquals( + "Should prune a single element from a nested struct", singleExpected, singleActual); + } + + @Test + public void testEntriesTableDataFilePruneMulti() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + DataFile file = table.currentSnapshot().addedDataFiles(table.io()).iterator().next(); + + List multiActual = + rowsToJava( + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .select( + "data_file.file_path", + "data_file.value_counts", + "data_file.record_count", + "data_file.column_sizes") + .collectAsList()); + + List multiExpected = + ImmutableList.of( + row(file.path(), file.valueCounts(), file.recordCount(), file.columnSizes())); + + assertEquals("Should prune a single element from a nested struct", multiExpected, multiActual); + } + + @Test + public void testFilesSelectMap() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset 
inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + DataFile file = table.currentSnapshot().addedDataFiles(table.io()).iterator().next(); + + List multiActual = + rowsToJava( + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "files")) + .select("file_path", "value_counts", "record_count", "column_sizes") + .collectAsList()); + + List multiExpected = + ImmutableList.of( + row(file.path(), file.valueCounts(), file.recordCount(), file.columnSizes())); + + assertEquals("Should prune a single element from a row", multiExpected, multiActual); + } + + @Test + public void testAllEntriesTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table entriesTable = loadTable(tableIdentifier, "all_entries"); + + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "b")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // delete the first file to test that not only live files are listed + table.newDelete().deleteFromRowFilter(Expressions.equal("id", 1)).commit(); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // ensure table data isn't stale + table.refresh(); + + Dataset entriesTableDs = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_entries")) + .orderBy("snapshot_id"); + List actual = TestHelpers.selectNonDerived(entriesTableDs).collectAsList(); + + List expected = Lists.newArrayList(); + for (ManifestFile manifest : + Iterables.concat(Iterables.transform(table.snapshots(), s -> s.allManifests(table.io())))) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = + Avro.read(in).project(entriesTable.schema()).build()) { + // each row must inherit snapshot_id and sequence_number + rows.forEach( + row -> { + if (row.get("snapshot_id").equals(table.currentSnapshot().snapshotId())) { + row.put(2, 3L); // data sequence number + row.put(3, 3L); // file sequence number + } else { + row.put(2, 1L); // data sequence number + row.put(3, 1L); // file sequence number + } + GenericData.Record file = (GenericData.Record) row.get("data_file"); + TestHelpers.asMetadataRecord(file); + expected.add(row); + }); + } + } + + expected.sort(Comparator.comparing(o -> (Long) o.get("snapshot_id"))); + + assertThat(expected).as("Entries table should have 3 rows").hasSize(3); + assertThat(actual).as("Actual results should have 3 rows").hasSize(3); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(entriesTableDs), expected.get(i), actual.get(i)); + } + } + + @Test + public void testCountEntriesTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "count_entries_test"); + createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + // init load + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + 
.format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + final int expectedEntryCount = 1; + + // count entries + assertThat( + spark.read().format("iceberg").load(loadLocation(tableIdentifier, "entries")).count()) + .as("Count should return " + expectedEntryCount) + .isEqualTo(expectedEntryCount); + + // count all_entries + assertThat( + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_entries")) + .count()) + .as("Count should return " + expectedEntryCount) + .isEqualTo(expectedEntryCount); + } + + @Test + public void testFilesTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table entriesTable = loadTable(tableIdentifier, "entries"); + + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // delete the first file to test that only live files are listed + table.newDelete().deleteFromRowFilter(Expressions.equal("id", 1)).commit(); + + Dataset filesTableDs = + spark.read().format("iceberg").load(loadLocation(tableIdentifier, "files")); + List actual = TestHelpers.selectNonDerived(filesTableDs).collectAsList(); + + List expected = Lists.newArrayList(); + for (ManifestFile manifest : table.currentSnapshot().dataManifests(table.io())) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = + Avro.read(in).project(entriesTable.schema()).build()) { + for (GenericData.Record record : rows) { + if ((Integer) record.get("status") < 2 /* added or existing */) { + GenericData.Record file = (GenericData.Record) record.get("data_file"); + TestHelpers.asMetadataRecord(file); + expected.add(file); + } + } + } + } + + assertThat(expected).as("Files table should have one row").hasSize(1); + assertThat(actual).as("Actual results should have one row").hasSize(1); + + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(filesTableDs), expected.get(0), actual.get(0)); + } + + @Test + public void testFilesTableWithSnapshotIdInheritance() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_inheritance_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + table.updateProperties().set(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, "true").commit(); + Table entriesTable = loadTable(tableIdentifier, "entries"); + + spark.sql( + String.format( + "CREATE TABLE parquet_table (data string, id int) " + + "USING parquet PARTITIONED BY (id) LOCATION '%s'", + temp.toFile())); + + List records = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + + Dataset inputDF = spark.createDataFrame(records, SimpleRecord.class); + inputDF.select("data", "id").write().mode("overwrite").insertInto("parquet_table"); + + NameMapping mapping = MappingUtil.create(table.schema()); + String mappingJson = NameMappingParser.toJson(mapping); + + table.updateProperties().set(TableProperties.DEFAULT_NAME_MAPPING, mappingJson).commit(); + + String stagingLocation = table.location() + "/metadata"; + SparkTableUtil.importSparkTable( + spark, + new 
org.apache.spark.sql.catalyst.TableIdentifier("parquet_table"), + table, + stagingLocation); + + Dataset filesTableDs = + spark.read().format("iceberg").load(loadLocation(tableIdentifier, "files")); + List actual = TestHelpers.selectNonDerived(filesTableDs).collectAsList(); + + List expected = Lists.newArrayList(); + for (ManifestFile manifest : table.currentSnapshot().dataManifests(table.io())) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = + Avro.read(in).project(entriesTable.schema()).build()) { + for (GenericData.Record record : rows) { + GenericData.Record file = (GenericData.Record) record.get("data_file"); + TestHelpers.asMetadataRecord(file); + expected.add(file); + } + } + } + + Types.StructType struct = TestHelpers.nonDerivedSchema(filesTableDs); + assertThat(expected).as("Files table should have 2 rows").hasSize(2); + assertThat(actual).as("Actual results should have 2 rows").hasSize(2); + TestHelpers.assertEqualsSafe(struct, expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(struct, expected.get(1), actual.get(1)); + } + + @Test + public void testV1EntriesTableWithSnapshotIdInheritance() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_inheritance_test"); + Map properties = ImmutableMap.of(TableProperties.FORMAT_VERSION, "1"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC, properties); + + table.updateProperties().set(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, "true").commit(); + + spark.sql( + String.format( + "CREATE TABLE parquet_table (data string, id int) " + + "USING parquet PARTITIONED BY (id) LOCATION '%s'", + temp.toFile())); + + List records = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + + Dataset inputDF = spark.createDataFrame(records, SimpleRecord.class); + inputDF.select("data", "id").write().mode("overwrite").insertInto("parquet_table"); + + String stagingLocation = table.location() + "/metadata"; + SparkTableUtil.importSparkTable( + spark, + new org.apache.spark.sql.catalyst.TableIdentifier("parquet_table"), + table, + stagingLocation); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .select("sequence_number", "snapshot_id", "data_file") + .collectAsList(); + + table.refresh(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + assertThat(actual).as("Entries table should have 2 rows").hasSize(2); + assertThat(actual.get(0).getLong(0)).as("Sequence number must match").isEqualTo(0); + assertThat(actual.get(0).getLong(1)).as("Snapshot id must match").isEqualTo(snapshotId); + assertThat(actual.get(1).getLong(0)).as("Sequence number must match").isEqualTo(0); + assertThat(actual.get(1).getLong(1)).as("Snapshot id must match").isEqualTo(snapshotId); + } + + @Test + public void testFilesUnpartitionedTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "unpartitioned_files_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table entriesTable = loadTable(tableIdentifier, "entries"); + + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + DataFile toDelete = + 
Iterables.getOnlyElement(table.currentSnapshot().addedDataFiles(table.io())); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // delete the first file to test that only live files are listed + table.newDelete().deleteFile(toDelete).commit(); + + Dataset filesTableDs = + spark.read().format("iceberg").load(loadLocation(tableIdentifier, "files")); + List actual = TestHelpers.selectNonDerived(filesTableDs).collectAsList(); + + List expected = Lists.newArrayList(); + for (ManifestFile manifest : table.currentSnapshot().dataManifests(table.io())) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = + Avro.read(in).project(entriesTable.schema()).build()) { + for (GenericData.Record record : rows) { + if ((Integer) record.get("status") < 2 /* added or existing */) { + GenericData.Record file = (GenericData.Record) record.get("data_file"); + TestHelpers.asMetadataRecord(file); + expected.add(file); + } + } + } + } + + assertThat(expected).as("Files table should have one row").hasSize(1); + assertThat(actual).as("Actual results should have one row").hasSize(1); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(filesTableDs), expected.get(0), actual.get(0)); + } + + @Test + public void testAllMetadataTablesWithStagedCommits() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "stage_aggregate_table_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + + table.updateProperties().set(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED, "true").commit(); + spark.conf().set(SparkSQLProperties.WAP_ID, "1234567"); + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + List actualAllData = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_data_files")) + .collectAsList(); + + List actualAllManifests = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_manifests")) + .collectAsList(); + + List actualAllEntries = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_entries")) + .collectAsList(); + + assertThat(table.snapshots().iterator()).as("Stage table should have some snapshots").hasNext(); + assertThat(table.currentSnapshot()).as("Stage table should have null currentSnapshot").isNull(); + assertThat(actualAllData).as("Actual results should have two rows").hasSize(2); + assertThat(actualAllManifests).as("Actual results should have two rows").hasSize(2); + assertThat(actualAllEntries).as("Actual results should have two rows").hasSize(2); + } + + @Test + public void testAllDataFilesTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table entriesTable = loadTable(tableIdentifier, "entries"); + + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + 
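// write one file, delete it with a row filter, then write a second file; all_data_files should still include the removed file + 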
df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // delete the first file to test that not only live files are listed + table.newDelete().deleteFromRowFilter(Expressions.equal("id", 1)).commit(); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // ensure table data isn't stale + table.refresh(); + + Dataset filesTableDs = + spark.read().format("iceberg").load(loadLocation(tableIdentifier, "all_data_files")); + List actual = TestHelpers.selectNonDerived(filesTableDs).collectAsList(); + actual.sort(Comparator.comparing(o -> o.getString(1))); + + List expected = Lists.newArrayList(); + Iterable dataManifests = + Iterables.concat( + Iterables.transform(table.snapshots(), snapshot -> snapshot.dataManifests(table.io()))); + for (ManifestFile manifest : dataManifests) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = + Avro.read(in).project(entriesTable.schema()).build()) { + for (GenericData.Record record : rows) { + if ((Integer) record.get("status") < 2 /* added or existing */) { + GenericData.Record file = (GenericData.Record) record.get("data_file"); + TestHelpers.asMetadataRecord(file); + expected.add(file); + } + } + } + } + + expected.sort(Comparator.comparing(o -> o.get("file_path").toString())); + + assertThat(expected).as("Files table should have two rows").hasSize(2); + assertThat(actual).as("Actual results should have two rows").hasSize(2); + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(filesTableDs), expected.get(i), actual.get(i)); + } + } + + @Test + public void testHistoryTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "history_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table historyTable = loadTable(tableIdentifier, "history"); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long secondSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long secondSnapshotId = table.currentSnapshot().snapshotId(); + + // rollback the table state to the first snapshot + table.manageSnapshots().rollbackTo(firstSnapshotId).commit(); + long rollbackTimestamp = Iterables.getLast(table.history()).timestampMillis(); + + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long thirdSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long thirdSnapshotId = table.currentSnapshot().snapshotId(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "history")) + .collectAsList(); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(historyTable.schema(), "history")); + List expected = + Lists.newArrayList( + builder + .set("made_current_at", 
firstSnapshotTimestamp * 1000) + .set("snapshot_id", firstSnapshotId) + .set("parent_id", null) + .set("is_current_ancestor", true) + .build(), + builder + .set("made_current_at", secondSnapshotTimestamp * 1000) + .set("snapshot_id", secondSnapshotId) + .set("parent_id", firstSnapshotId) + .set( + "is_current_ancestor", + false) // commit rolled back, not an ancestor of the current table state + .build(), + builder + .set("made_current_at", rollbackTimestamp * 1000) + .set("snapshot_id", firstSnapshotId) + .set("parent_id", null) + .set("is_current_ancestor", true) + .build(), + builder + .set("made_current_at", thirdSnapshotTimestamp * 1000) + .set("snapshot_id", thirdSnapshotId) + .set("parent_id", firstSnapshotId) + .set("is_current_ancestor", true) + .build()); + + assertThat(actual).as("History table should have a row for each commit").hasSize(4); + TestHelpers.assertEqualsSafe(historyTable.schema().asStruct(), expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(historyTable.schema().asStruct(), expected.get(1), actual.get(1)); + TestHelpers.assertEqualsSafe(historyTable.schema().asStruct(), expected.get(2), actual.get(2)); + TestHelpers.assertEqualsSafe(historyTable.schema().asStruct(), expected.get(3), actual.get(3)); + } + + @Test + public void testSnapshotsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "snapshots_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table snapTable = loadTable(tableIdentifier, "snapshots"); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + String firstManifestList = table.currentSnapshot().manifestListLocation(); + + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + + long secondSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long secondSnapshotId = table.currentSnapshot().snapshotId(); + String secondManifestList = table.currentSnapshot().manifestListLocation(); + + // rollback the table state to the first snapshot + table.manageSnapshots().rollbackTo(firstSnapshotId).commit(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "snapshots")) + .collectAsList(); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(snapTable.schema(), "snapshots")); + List expected = + Lists.newArrayList( + builder + .set("committed_at", firstSnapshotTimestamp * 1000) + .set("snapshot_id", firstSnapshotId) + .set("parent_id", null) + .set("operation", "append") + .set("manifest_list", firstManifestList) + .set( + "summary", + ImmutableMap.of( + "added-records", "1", + "added-data-files", "1", + "changed-partition-count", "1", + "total-data-files", "1", + "total-records", "1")) + .build(), + builder + .set("committed_at", secondSnapshotTimestamp * 1000) + .set("snapshot_id", secondSnapshotId) + .set("parent_id", firstSnapshotId) + .set("operation", "delete") + .set("manifest_list", secondManifestList) + .set( + "summary", + ImmutableMap.of( + "deleted-records", "1", + "deleted-data-files", "1", + "changed-partition-count", "1", + "total-records", "0", + "total-data-files", "0")) + .build()); + + 
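// the snapshots table should list both snapshots even though the table was rolled back to the first one + 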
assertThat(actual).as("Snapshots table should have a row for each snapshot").hasSize(2); + TestHelpers.assertEqualsSafe(snapTable.schema().asStruct(), expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(snapTable.schema().asStruct(), expected.get(1), actual.get(1)); + } + + @Test + public void testPrunedSnapshotsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "snapshots_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + + long secondSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + + // rollback the table state to the first snapshot + table.manageSnapshots().rollbackTo(firstSnapshotId).commit(); + + Dataset actualDf = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "snapshots")) + .select("operation", "committed_at", "summary", "parent_id"); + + Schema projectedSchema = SparkSchemaUtil.convert(actualDf.schema()); + + List actual = actualDf.collectAsList(); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(projectedSchema, "snapshots")); + List expected = + Lists.newArrayList( + builder + .set("committed_at", firstSnapshotTimestamp * 1000) + .set("parent_id", null) + .set("operation", "append") + .set( + "summary", + ImmutableMap.of( + "added-records", "1", + "added-data-files", "1", + "changed-partition-count", "1", + "total-data-files", "1", + "total-records", "1")) + .build(), + builder + .set("committed_at", secondSnapshotTimestamp * 1000) + .set("parent_id", firstSnapshotId) + .set("operation", "delete") + .set( + "summary", + ImmutableMap.of( + "deleted-records", "1", + "deleted-data-files", "1", + "changed-partition-count", "1", + "total-records", "0", + "total-data-files", "0")) + .build()); + + assertThat(actual).as("Snapshots table should have a row for each snapshot").hasSize(2); + TestHelpers.assertEqualsSafe(projectedSchema.asStruct(), expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(projectedSchema.asStruct(), expected.get(1), actual.get(1)); + } + + @Test + public void testManifestsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "manifests_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table manifestTable = loadTable(tableIdentifier, "manifests"); + Dataset df1 = + spark.createDataFrame( + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(null, "b")), + SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .option(SparkWriteOptions.DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_NONE) + .save(loadLocation(tableIdentifier)); + + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + DeleteFile deleteFile = writePosDeleteFile(table); + + table.newRowDelta().addDeletes(deleteFile).commit(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "manifests")) + .collectAsList(); + + table.refresh(); + + GenericRecordBuilder builder = + new 
GenericRecordBuilder(AvroSchemaUtil.convert(manifestTable.schema(), "manifests")); + GenericRecordBuilder summaryBuilder = + new GenericRecordBuilder( + AvroSchemaUtil.convert( + manifestTable.schema().findType("partition_summaries.element").asStructType(), + "partition_summary")); + List expected = + Lists.transform( + table.currentSnapshot().allManifests(table.io()), + manifest -> + builder + .set("content", manifest.content().id()) + .set("path", manifest.path()) + .set("length", manifest.length()) + .set("partition_spec_id", manifest.partitionSpecId()) + .set("added_snapshot_id", manifest.snapshotId()) + .set( + "added_data_files_count", + manifest.content() == DATA ? manifest.addedFilesCount() : 0) + .set( + "existing_data_files_count", + manifest.content() == DATA ? manifest.existingFilesCount() : 0) + .set( + "deleted_data_files_count", + manifest.content() == DATA ? manifest.deletedFilesCount() : 0) + .set( + "added_delete_files_count", + manifest.content() == DELETES ? manifest.addedFilesCount() : 0) + .set( + "existing_delete_files_count", + manifest.content() == DELETES ? manifest.existingFilesCount() : 0) + .set( + "deleted_delete_files_count", + manifest.content() == DELETES ? manifest.deletedFilesCount() : 0) + .set( + "partition_summaries", + Lists.transform( + manifest.partitions(), + partition -> + summaryBuilder + .set("contains_null", manifest.content() == DATA) + .set("contains_nan", false) + .set("lower_bound", "1") + .set("upper_bound", "1") + .build())) + .build()); + + assertThat(actual).as("Manifests table should have two manifest rows").hasSize(2); + TestHelpers.assertEqualsSafe(manifestTable.schema().asStruct(), expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(manifestTable.schema().asStruct(), expected.get(1), actual.get(1)); + } + + @Test + public void testPruneManifestsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "manifests_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table manifestTable = loadTable(tableIdentifier, "manifests"); + Dataset df1 = + spark.createDataFrame( + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(null, "b")), + SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + if (!spark.version().startsWith("2")) { + // Spark 2 isn't able to actually push down nested struct projections so this will not break + assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "manifests")) + .select("partition_spec_id", "path", "partition_summaries.contains_null") + .collectAsList()) + .isInstanceOf(SparkException.class) + .hasMessageContaining("Cannot project a partial list element struct"); + } + + Dataset actualDf = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "manifests")) + .select("partition_spec_id", "path", "partition_summaries"); + + Schema projectedSchema = SparkSchemaUtil.convert(actualDf.schema()); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "manifests")) + .select("partition_spec_id", "path", "partition_summaries") + .collectAsList(); + + table.refresh(); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(projectedSchema.asStruct())); + GenericRecordBuilder summaryBuilder = + new GenericRecordBuilder( + AvroSchemaUtil.convert( + projectedSchema.findType("partition_summaries.element").asStructType(), + 
"partition_summary")); + List expected = + Lists.transform( + table.currentSnapshot().allManifests(table.io()), + manifest -> + builder + .set("partition_spec_id", manifest.partitionSpecId()) + .set("path", manifest.path()) + .set( + "partition_summaries", + Lists.transform( + manifest.partitions(), + partition -> + summaryBuilder + .set("contains_null", true) + .set("contains_nan", false) + .set("lower_bound", "1") + .set("upper_bound", "1") + .build())) + .build()); + + assertThat(actual).as("Manifests table should have one manifest row").hasSize(1); + TestHelpers.assertEqualsSafe(projectedSchema.asStruct(), expected.get(0), actual.get(0)); + } + + @Test + public void testAllManifestsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "manifests_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table manifestTable = loadTable(tableIdentifier, "all_manifests"); + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + DeleteFile deleteFile = writePosDeleteFile(table); + + table.newRowDelta().addDeletes(deleteFile).commit(); + + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + + Stream> snapshotIdToManifests = + StreamSupport.stream(table.snapshots().spliterator(), false) + .flatMap( + snapshot -> + snapshot.allManifests(table.io()).stream() + .map(manifest -> Pair.of(snapshot.snapshotId(), manifest))); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_manifests")) + .orderBy("path") + .collectAsList(); + + table.refresh(); + + List expected = + snapshotIdToManifests + .map( + snapshotManifest -> + manifestRecord( + manifestTable, snapshotManifest.first(), snapshotManifest.second())) + .sorted(Comparator.comparing(o -> o.get("path").toString())) + .collect(Collectors.toList()); + + assertThat(actual).as("Manifests table should have 5 manifest rows").hasSize(5); + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe( + manifestTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + } + + @Test + public void testUnpartitionedPartitionsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "unpartitioned_partitions_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + Dataset df = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + Types.StructType expectedSchema = + Types.StructType.of( + required(2, "record_count", Types.LongType.get(), "Count of records in data files"), + required(3, "file_count", Types.IntegerType.get(), "Count of data files"), + required( + 11, + "total_data_file_size_in_bytes", + Types.LongType.get(), + "Total size in bytes of data files"), + required( + 5, + "position_delete_record_count", + Types.LongType.get(), + "Count of records in position delete files"), + required( + 6, + "position_delete_file_count", + Types.IntegerType.get(), + "Count of position delete files"), + required( + 7, + "equality_delete_record_count", + Types.LongType.get(), + "Count of records in equality delete files"), + required( + 8, + "equality_delete_file_count", + 
Types.IntegerType.get(), + "Count of equality delete files"), + optional( + 9, + "last_updated_at", + Types.TimestampType.withZone(), + "Commit time of snapshot that last updated this partition"), + optional( + 10, + "last_updated_snapshot_id", + Types.LongType.get(), + "Id of snapshot that last updated this partition")); + + Table partitionsTable = loadTable(tableIdentifier, "partitions"); + + assertThat(expectedSchema) + .as("Schema should not have partition field") + .isEqualTo(partitionsTable.schema().asStruct()); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(partitionsTable.schema(), "partitions")); + GenericData.Record expectedRow = + builder + .set("last_updated_at", table.currentSnapshot().timestampMillis() * 1000) + .set("last_updated_snapshot_id", table.currentSnapshot().snapshotId()) + .set("record_count", 1L) + .set("file_count", 1) + .set( + "total_data_file_size_in_bytes", + totalSizeInBytes(table.currentSnapshot().addedDataFiles(table.io()))) + .set("position_delete_record_count", 0L) + .set("position_delete_file_count", 0) + .set("equality_delete_record_count", 0L) + .set("equality_delete_file_count", 0) + .build(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .collectAsList(); + + assertThat(actual).as("Unpartitioned partitions table should have one row").hasSize(1); + TestHelpers.assertEqualsSafe(expectedSchema, expectedRow, actual.get(0)); + } + + @Test + public void testPartitionsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "partitions_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table partitionsTable = loadTable(tableIdentifier, "partitions"); + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstCommitId = table.currentSnapshot().snapshotId(); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long secondCommitId = table.currentSnapshot().snapshotId(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .orderBy("partition.id") + .collectAsList(); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(partitionsTable.schema(), "partitions")); + GenericRecordBuilder partitionBuilder = + new GenericRecordBuilder( + AvroSchemaUtil.convert( + partitionsTable.schema().findType("partition").asStructType(), "partition")); + List expected = Lists.newArrayList(); + expected.add( + builder + .set("partition", partitionBuilder.set("id", 1).build()) + .set("record_count", 1L) + .set("file_count", 1) + .set( + "total_data_file_size_in_bytes", + totalSizeInBytes(table.snapshot(firstCommitId).addedDataFiles(table.io()))) + .set("position_delete_record_count", 0L) + .set("position_delete_file_count", 0) + .set("equality_delete_record_count", 0L) + .set("equality_delete_file_count", 0) + .set("spec_id", 0) + .set("last_updated_at", table.snapshot(firstCommitId).timestampMillis() * 1000) + .set("last_updated_snapshot_id", firstCommitId) + .build()); + expected.add( + builder + .set("partition", partitionBuilder.set("id", 
2).build()) + .set("record_count", 1L) + .set("file_count", 1) + .set( + "total_data_file_size_in_bytes", + totalSizeInBytes(table.snapshot(secondCommitId).addedDataFiles(table.io()))) + .set("position_delete_record_count", 0L) + .set("position_delete_file_count", 0) + .set("equality_delete_record_count", 0L) + .set("equality_delete_file_count", 0) + .set("spec_id", 0) + .set("last_updated_at", table.snapshot(secondCommitId).timestampMillis() * 1000) + .set("last_updated_snapshot_id", secondCommitId) + .build()); + + assertThat(expected).as("Partitions table should have two rows").hasSize(2); + assertThat(actual).as("Actual results should have two rows").hasSize(2); + for (int i = 0; i < 2; i += 1) { + TestHelpers.assertEqualsSafe( + partitionsTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + + // check time travel + List actualAfterFirstCommit = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, String.valueOf(firstCommitId)) + .load(loadLocation(tableIdentifier, "partitions")) + .orderBy("partition.id") + .collectAsList(); + + assertThat(actualAfterFirstCommit).as("Actual results should have one row").hasSize(1); + TestHelpers.assertEqualsSafe( + partitionsTable.schema().asStruct(), expected.get(0), actualAfterFirstCommit.get(0)); + + // check predicate push down + List filtered = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .filter("partition.id < 2") + .collectAsList(); + + assertThat(filtered).as("Actual results should have one row").hasSize(1); + TestHelpers.assertEqualsSafe( + partitionsTable.schema().asStruct(), expected.get(0), filtered.get(0)); + + List nonFiltered = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .filter("partition.id < 2 or record_count=1") + .collectAsList(); + + assertThat(nonFiltered).as("Actual results should have two rows").hasSize(2); + for (int i = 0; i < 2; i += 1) { + TestHelpers.assertEqualsSafe( + partitionsTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + } + + @Test + public void testPartitionsTableLastUpdatedSnapshot() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "partitions_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table partitionsTable = loadTable(tableIdentifier, "partitions"); + Dataset df1 = + spark.createDataFrame( + Lists.newArrayList(new SimpleRecord(1, "1"), new SimpleRecord(2, "2")), + SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "20")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstCommitId = table.currentSnapshot().snapshotId(); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long secondCommitId = table.currentSnapshot().snapshotId(); + + // check if rewrite manifest does not override metadata about data file's creating snapshot + RewriteManifests.Result rewriteManifestResult = + SparkActions.get().rewriteManifests(table).execute(); + assertThat(rewriteManifestResult.rewrittenManifests()) + .as("rewrite replaced 2 manifests") + .hasSize(2); + + assertThat(rewriteManifestResult.addedManifests()).as("rewrite added 1 manifests").hasSize(1); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) 
+ .orderBy("partition.id") + .collectAsList(); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(partitionsTable.schema(), "partitions")); + GenericRecordBuilder partitionBuilder = + new GenericRecordBuilder( + AvroSchemaUtil.convert( + partitionsTable.schema().findType("partition").asStructType(), "partition")); + + List dataFiles = TestHelpers.dataFiles(table); + assertDataFilePartitions(dataFiles, Arrays.asList(1, 2, 2)); + + List expected = Lists.newArrayList(); + expected.add( + builder + .set("partition", partitionBuilder.set("id", 1).build()) + .set("record_count", 1L) + .set("file_count", 1) + .set("total_data_file_size_in_bytes", dataFiles.get(0).fileSizeInBytes()) + .set("position_delete_record_count", 0L) + .set("position_delete_file_count", 0) + .set("equality_delete_record_count", 0L) + .set("equality_delete_file_count", 0) + .set("spec_id", 0) + .set("last_updated_at", table.snapshot(firstCommitId).timestampMillis() * 1000) + .set("last_updated_snapshot_id", firstCommitId) + .build()); + expected.add( + builder + .set("partition", partitionBuilder.set("id", 2).build()) + .set("record_count", 2L) + .set("file_count", 2) + .set( + "total_data_file_size_in_bytes", + dataFiles.get(1).fileSizeInBytes() + dataFiles.get(2).fileSizeInBytes()) + .set("position_delete_record_count", 0L) + .set("position_delete_file_count", 0) + .set("equality_delete_record_count", 0L) + .set("equality_delete_file_count", 0) + .set("spec_id", 0) + .set("last_updated_at", table.snapshot(secondCommitId).timestampMillis() * 1000) + .set("last_updated_snapshot_id", secondCommitId) + .build()); + + assertThat(expected).as("Partitions table should have two rows").hasSize(2); + assertThat(actual).as("Actual results should have two rows").hasSize(2); + for (int i = 0; i < 2; i += 1) { + TestHelpers.assertEqualsSafe( + partitionsTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + + // check predicate push down + List filtered = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .filter("partition.id < 2") + .collectAsList(); + assertThat(filtered).as("Actual results should have one row").hasSize(1); + TestHelpers.assertEqualsSafe( + partitionsTable.schema().asStruct(), expected.get(0), filtered.get(0)); + + // check for snapshot expiration + // if snapshot with firstCommitId is expired, + // we expect the partition of id=1 will no longer have last updated timestamp and snapshotId + SparkActions.get().expireSnapshots(table).expireSnapshotId(firstCommitId).execute(); + GenericData.Record newPartitionRecord = + builder + .set("partition", partitionBuilder.set("id", 1).build()) + .set("record_count", 1L) + .set("file_count", 1) + .set("total_data_file_size_in_bytes", dataFiles.get(0).fileSizeInBytes()) + .set("position_delete_record_count", 0L) + .set("position_delete_file_count", 0) + .set("equality_delete_record_count", 0L) + .set("equality_delete_file_count", 0) + .set("spec_id", 0) + .set("last_updated_at", null) + .set("last_updated_snapshot_id", null) + .build(); + expected.remove(0); + expected.add(0, newPartitionRecord); + + List actualAfterSnapshotExpiration = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .collectAsList(); + assertThat(actualAfterSnapshotExpiration).as("Actual results should have two rows").hasSize(2); + for (int i = 0; i < 2; i += 1) { + TestHelpers.assertEqualsSafe( + partitionsTable.schema().asStruct(), + expected.get(i), + 
actualAfterSnapshotExpiration.get(i)); + } + } + + @Test + public void testPartitionsTableDeleteStats() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "partitions_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table partitionsTable = loadTable(tableIdentifier, "partitions"); + Dataset df1 = + spark.createDataFrame( + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(1, "b"), new SimpleRecord(1, "c")), + SimpleRecord.class); + Dataset df2 = + spark.createDataFrame( + Lists.newArrayList( + new SimpleRecord(2, "d"), new SimpleRecord(2, "e"), new SimpleRecord(2, "f")), + SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstCommitId = table.currentSnapshot().snapshotId(); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // test position deletes + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + DeleteFile deleteFile1 = writePosDeleteFile(table, 0); + DeleteFile deleteFile2 = writePosDeleteFile(table, 1); + table.newRowDelta().addDeletes(deleteFile1).addDeletes(deleteFile2).commit(); + table.refresh(); + long posDeleteCommitId = table.currentSnapshot().snapshotId(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .orderBy("partition.id") + .collectAsList(); + assertThat(actual).as("Actual results should have two rows").hasSize(2); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(partitionsTable.schema(), "partitions")); + GenericRecordBuilder partitionBuilder = + new GenericRecordBuilder( + AvroSchemaUtil.convert( + partitionsTable.schema().findType("partition").asStructType(), "partition")); + List expected = Lists.newArrayList(); + expected.add( + builder + .set("partition", partitionBuilder.set("id", 1).build()) + .set("record_count", 3L) + .set("file_count", 1) + .set( + "total_data_file_size_in_bytes", + totalSizeInBytes(table.snapshot(firstCommitId).addedDataFiles(table.io()))) + .set("position_delete_record_count", 0L) + .set("position_delete_file_count", 0) + .set("equality_delete_record_count", 0L) + .set("equality_delete_file_count", 0) + .set("spec_id", 0) + .set("last_updated_at", table.snapshot(firstCommitId).timestampMillis() * 1000) + .set("last_updated_snapshot_id", firstCommitId) + .build()); + expected.add( + builder + .set("partition", partitionBuilder.set("id", 2).build()) + .set("record_count", 3L) + .set("file_count", 1) + .set( + "total_data_file_size_in_bytes", + totalSizeInBytes(table.snapshot(firstCommitId).addedDataFiles(table.io()))) + .set("position_delete_record_count", 2L) // should be incremented now + .set("position_delete_file_count", 2) // should be incremented now + .set("equality_delete_record_count", 0L) + .set("equality_delete_file_count", 0) + .set("spec_id", 0) + .set("last_updated_at", table.snapshot(posDeleteCommitId).timestampMillis() * 1000) + .set("last_updated_snapshot_id", posDeleteCommitId) + .build()); + + for (int i = 0; i < 2; i += 1) { + TestHelpers.assertEqualsSafe( + partitionsTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + + // test equality delete + DeleteFile eqDeleteFile1 = writeEqDeleteFile(table, "d"); + DeleteFile eqDeleteFile2 = writeEqDeleteFile(table, "f"); + 
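// both equality delete files are written to partition id=1, so only that partition's equality delete counts change + 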
table.newRowDelta().addDeletes(eqDeleteFile1).addDeletes(eqDeleteFile2).commit(); + table.refresh(); + long eqDeleteCommitId = table.currentSnapshot().snapshotId(); + actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .orderBy("partition.id") + .collectAsList(); + assertThat(actual).as("Actual results should have two rows").hasSize(2); + expected.remove(0); + expected.add( + 0, + builder + .set("partition", partitionBuilder.set("id", 1).build()) + .set("record_count", 3L) + .set("file_count", 1) + .set("position_delete_record_count", 0L) + .set("position_delete_file_count", 0) + .set("equality_delete_record_count", 2L) // should be incremented now + .set("equality_delete_file_count", 2) // should be incremented now + .set("last_updated_at", table.snapshot(eqDeleteCommitId).timestampMillis() * 1000) + .set("last_updated_snapshot_id", eqDeleteCommitId) + .build()); + for (int i = 0; i < 2; i += 1) { + TestHelpers.assertEqualsSafe( + partitionsTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + } + + @Test + public synchronized void testSnapshotReadAfterAddColumn() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List originalRecords = + Lists.newArrayList( + RowFactory.create(1, "x"), RowFactory.create(2, "y"), RowFactory.create(3, "z")); + + StructType originalSparkSchema = SparkSchemaUtil.convert(SCHEMA); + Dataset inputDf = spark.createDataFrame(originalRecords, originalSparkSchema); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + Dataset resultDf = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + assertThat(originalRecords) + .as("Records should match") + .isEqualTo(resultDf.orderBy("id").collectAsList()); + + Snapshot snapshotBeforeAddColumn = table.currentSnapshot(); + + table.updateSchema().addColumn("category", Types.StringType.get()).commit(); + + List newRecords = + Lists.newArrayList(RowFactory.create(4, "xy", "B"), RowFactory.create(5, "xyz", "C")); + + StructType newSparkSchema = SparkSchemaUtil.convert(SCHEMA2); + Dataset inputDf2 = spark.createDataFrame(newRecords, newSparkSchema); + inputDf2 + .select("id", "data", "category") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List updatedRecords = + Lists.newArrayList( + RowFactory.create(1, "x", null), + RowFactory.create(2, "y", null), + RowFactory.create(3, "z", null), + RowFactory.create(4, "xy", "B"), + RowFactory.create(5, "xyz", "C")); + + Dataset resultDf2 = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + assertThat(updatedRecords) + .as("Records should match") + .isEqualTo(resultDf2.orderBy("id").collectAsList()); + + Dataset resultDf3 = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotBeforeAddColumn.snapshotId()) + .load(loadLocation(tableIdentifier)); + + assertThat(originalRecords) + .as("Records should match") + .isEqualTo(resultDf3.orderBy("id").collectAsList()); + + assertThat(resultDf3.schema()).as("Schemas should match").isEqualTo(originalSparkSchema); + } + + @Test + public synchronized void testSnapshotReadAfterDropColumn() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + Table table = createTable(tableIdentifier, SCHEMA2, 
PartitionSpec.unpartitioned()); + + List<Row> originalRecords = + Lists.newArrayList( + RowFactory.create(1, "x", "A"), + RowFactory.create(2, "y", "A"), + RowFactory.create(3, "z", "B")); + + StructType originalSparkSchema = SparkSchemaUtil.convert(SCHEMA2); + Dataset<Row> inputDf = spark.createDataFrame(originalRecords, originalSparkSchema); + inputDf + .select("id", "data", "category") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + Dataset<Row> resultDf = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + + assertThat(resultDf.orderBy("id").collectAsList()) + .as("Records should match") + .isEqualTo(originalRecords); + + long tsBeforeDropColumn = waitUntilAfter(System.currentTimeMillis()); + table.updateSchema().deleteColumn("data").commit(); + long tsAfterDropColumn = waitUntilAfter(System.currentTimeMillis()); + + List<Row> newRecords = Lists.newArrayList(RowFactory.create(4, "B"), RowFactory.create(5, "C")); + + StructType newSparkSchema = SparkSchemaUtil.convert(SCHEMA3); + Dataset<Row> inputDf2 = spark.createDataFrame(newRecords, newSparkSchema); + inputDf2 + .select("id", "category") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List<Row> updatedRecords = + Lists.newArrayList( + RowFactory.create(1, "A"), + RowFactory.create(2, "A"), + RowFactory.create(3, "B"), + RowFactory.create(4, "B"), + RowFactory.create(5, "C")); + + Dataset<Row> resultDf2 = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + assertThat(resultDf2.orderBy("id").collectAsList()) + .as("Records should match") + .isEqualTo(updatedRecords); + + Dataset<Row> resultDf3 = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, tsBeforeDropColumn) + .load(loadLocation(tableIdentifier)); + + assertThat(resultDf3.orderBy("id").collectAsList()) + .as("Records should match") + .isEqualTo(originalRecords); + + assertThat(resultDf3.schema()).as("Schemas should match").isEqualTo(originalSparkSchema); + + // At tsAfterDropColumn, there has been a schema change, but no new snapshot, + // so the snapshot as of tsAfterDropColumn is the same as that as of tsBeforeDropColumn. 
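+ // The reads below therefore return the original rows, using the snapshot's schema, which still has the "data" column.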
+ Dataset resultDf4 = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, tsAfterDropColumn) + .load(loadLocation(tableIdentifier)); + + assertThat(resultDf4.orderBy("id").collectAsList()) + .as("Records should match") + .isEqualTo(originalRecords); + + assertThat(resultDf4.schema()).as("Schemas should match").isEqualTo(originalSparkSchema); + } + + @Test + public synchronized void testSnapshotReadAfterAddAndDropColumn() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List originalRecords = + Lists.newArrayList( + RowFactory.create(1, "x"), RowFactory.create(2, "y"), RowFactory.create(3, "z")); + + StructType originalSparkSchema = SparkSchemaUtil.convert(SCHEMA); + Dataset inputDf = spark.createDataFrame(originalRecords, originalSparkSchema); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + Dataset resultDf = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + + assertThat(resultDf.orderBy("id").collectAsList()) + .as("Records should match") + .isEqualTo(originalRecords); + + Snapshot snapshotBeforeAddColumn = table.currentSnapshot(); + + table.updateSchema().addColumn("category", Types.StringType.get()).commit(); + + List newRecords = + Lists.newArrayList(RowFactory.create(4, "xy", "B"), RowFactory.create(5, "xyz", "C")); + + StructType sparkSchemaAfterAddColumn = SparkSchemaUtil.convert(SCHEMA2); + Dataset inputDf2 = spark.createDataFrame(newRecords, sparkSchemaAfterAddColumn); + inputDf2 + .select("id", "data", "category") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List updatedRecords = + Lists.newArrayList( + RowFactory.create(1, "x", null), + RowFactory.create(2, "y", null), + RowFactory.create(3, "z", null), + RowFactory.create(4, "xy", "B"), + RowFactory.create(5, "xyz", "C")); + + Dataset resultDf2 = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + + assertThat(resultDf2.orderBy("id").collectAsList()) + .as("Records should match") + .isEqualTo(updatedRecords); + + table.updateSchema().deleteColumn("data").commit(); + + List recordsAfterDropColumn = + Lists.newArrayList( + RowFactory.create(1, null), + RowFactory.create(2, null), + RowFactory.create(3, null), + RowFactory.create(4, "B"), + RowFactory.create(5, "C")); + + Dataset resultDf3 = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + + assertThat(resultDf3.orderBy("id").collectAsList()) + .as("Records should match") + .isEqualTo(recordsAfterDropColumn); + + Dataset resultDf4 = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotBeforeAddColumn.snapshotId()) + .load(loadLocation(tableIdentifier)); + + assertThat(resultDf4.orderBy("id").collectAsList()) + .as("Records should match") + .isEqualTo(originalRecords); + + assertThat(resultDf4.schema()).as("Schemas should match").isEqualTo(originalSparkSchema); + } + + @Test + public void testRemoveOrphanFilesActionSupport() throws InterruptedException { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data") + 
.write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + df.write().mode("append").parquet(table.location() + "/data"); + + // sleep for 1 second to ensure files will be old enough + Thread.sleep(1000); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result1 = + actions + .deleteOrphanFiles(table) + .location(table.location() + "/metadata") + .olderThan(System.currentTimeMillis()) + .execute(); + + assertThat(result1.orphanFileLocations()).as("Should not delete any metadata files").isEmpty(); + + DeleteOrphanFiles.Result result2 = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + assertThat(result2.orphanFileLocations()).as("Should delete 1 data file").hasSize(1); + + Dataset resultDF = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + List actualRecords = + resultDF.as(Encoders.bean(SimpleRecord.class)).collectAsList(); + + assertThat(actualRecords).as("Rows must match").isEqualTo(records); + } + + @Test + public void testFilesTablePartitionId() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_test"); + Table table = + createTable( + tableIdentifier, SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("id").build()); + int spec0 = table.spec().specId(); + + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // change partition spec + table.refresh(); + table.updateSpec().removeField("id").commit(); + int spec1 = table.spec().specId(); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "files")) + .sort(DataFile.SPEC_ID.name()) + .collectAsList() + .stream() + .map(r -> (Integer) r.getAs(DataFile.SPEC_ID.name())) + .collect(Collectors.toList()); + + assertThat(ImmutableList.of(spec0, spec1)) + .as("Should have two partition specs") + .isEqualTo(actual); + } + + @Test + public void testAllManifestTableSnapshotFiltering() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "all_manifest_snapshot_filtering"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table manifestTable = loadTable(tableIdentifier, "all_manifests"); + Dataset df = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + + List> snapshotIdToManifests = Lists.newArrayList(); + + df.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + Snapshot snapshot1 = table.currentSnapshot(); + snapshotIdToManifests.addAll( + snapshot1.allManifests(table.io()).stream() + .map(manifest -> Pair.of(snapshot1.snapshotId(), manifest)) + .collect(Collectors.toList())); + + df.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + Snapshot snapshot2 = table.currentSnapshot(); + assertThat(snapshot2.allManifests(table.io())).as("Should have two manifests").hasSize(2); + snapshotIdToManifests.addAll( + snapshot2.allManifests(table.io()).stream() + .map(manifest -> Pair.of(snapshot2.snapshotId(), manifest)) + 
.collect(Collectors.toList())); + + // Add manifests that will not be selected + df.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + df.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + StringJoiner snapshotIds = new StringJoiner(",", "(", ")"); + snapshotIds.add(String.valueOf(snapshot1.snapshotId())); + snapshotIds.add(String.valueOf(snapshot2.snapshotId())); + snapshotIds.toString(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_manifests")) + .filter("reference_snapshot_id in " + snapshotIds) + .orderBy("path") + .collectAsList(); + table.refresh(); + + List expected = + snapshotIdToManifests.stream() + .map( + snapshotManifest -> + manifestRecord( + manifestTable, snapshotManifest.first(), snapshotManifest.second())) + .sorted(Comparator.comparing(o -> o.get("path").toString())) + .collect(Collectors.toList()); + + assertThat(actual).as("Manifests table should have 3 manifest rows").hasSize(3); + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe( + manifestTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + } + + @Test + public void testTableWithInt96Timestamp() throws IOException { + File parquetTableDir = temp.resolve("table_timestamp_int96").toFile(); + String parquetTableLocation = parquetTableDir.toURI().toString(); + Schema schema = + new Schema( + optional(1, "id", Types.LongType.get()), + optional(2, "tmp_col", Types.TimestampType.withZone())); + spark.conf().set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE().key(), "INT96"); + + LocalDateTime start = LocalDateTime.of(2000, 1, 31, 0, 0, 0); + LocalDateTime end = LocalDateTime.of(2100, 1, 1, 0, 0, 0); + long startSec = start.toEpochSecond(ZoneOffset.UTC); + long endSec = end.toEpochSecond(ZoneOffset.UTC); + Column idColumn = functions.expr("id"); + Column secondsColumn = + functions.expr("(id % " + (endSec - startSec) + " + " + startSec + ")").as("seconds"); + Column timestampColumn = functions.expr("cast( seconds as timestamp) as tmp_col"); + + for (Boolean useDict : new Boolean[] {true, false}) { + for (Boolean useVectorization : new Boolean[] {true, false}) { + spark.sql("DROP TABLE IF EXISTS parquet_table"); + spark + .range(0, 5000, 100, 1) + .select(idColumn, secondsColumn) + .select(idColumn, timestampColumn) + .write() + .format("parquet") + .option("parquet.enable.dictionary", useDict) + .mode("overwrite") + .option("path", parquetTableLocation) + .saveAsTable("parquet_table"); + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table_with_timestamp_int96"); + Table table = createTable(tableIdentifier, schema, PartitionSpec.unpartitioned()); + table + .updateProperties() + .set(TableProperties.PARQUET_VECTORIZATION_ENABLED, useVectorization.toString()) + .commit(); + + String stagingLocation = table.location() + "/metadata"; + SparkTableUtil.importSparkTable( + spark, + new org.apache.spark.sql.catalyst.TableIdentifier("parquet_table"), + table, + stagingLocation); + + // validate we get the expected results back + testWithFilter("tmp_col < to_timestamp('2000-01-31 08:30:00')", tableIdentifier); + testWithFilter("tmp_col <= to_timestamp('2000-01-31 08:30:00')", tableIdentifier); + testWithFilter("tmp_col == to_timestamp('2000-01-31 08:30:00')", tableIdentifier); + testWithFilter("tmp_col > to_timestamp('2000-01-31 08:30:00')", tableIdentifier); + testWithFilter("tmp_col >= 
to_timestamp('2000-01-31 08:30:00')", tableIdentifier); + dropTable(tableIdentifier); + } + } + } + + private void testWithFilter(String filterExpr, TableIdentifier tableIdentifier) { + List expected = + spark.table("parquet_table").select("tmp_col").filter(filterExpr).collectAsList(); + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier)) + .select("tmp_col") + .filter(filterExpr) + .collectAsList(); + assertThat(actual).as("Rows must match").containsExactlyInAnyOrderElementsOf(expected); + } + + private GenericData.Record manifestRecord( + Table manifestTable, Long referenceSnapshotId, ManifestFile manifest) { + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(manifestTable.schema(), "manifests")); + GenericRecordBuilder summaryBuilder = + new GenericRecordBuilder( + AvroSchemaUtil.convert( + manifestTable.schema().findType("partition_summaries.element").asStructType(), + "partition_summary")); + return builder + .set("content", manifest.content().id()) + .set("path", manifest.path()) + .set("length", manifest.length()) + .set("partition_spec_id", manifest.partitionSpecId()) + .set("added_snapshot_id", manifest.snapshotId()) + .set("added_data_files_count", manifest.content() == DATA ? manifest.addedFilesCount() : 0) + .set( + "existing_data_files_count", + manifest.content() == DATA ? manifest.existingFilesCount() : 0) + .set( + "deleted_data_files_count", + manifest.content() == DATA ? manifest.deletedFilesCount() : 0) + .set( + "added_delete_files_count", + manifest.content() == DELETES ? manifest.addedFilesCount() : 0) + .set( + "existing_delete_files_count", + manifest.content() == DELETES ? manifest.existingFilesCount() : 0) + .set( + "deleted_delete_files_count", + manifest.content() == DELETES ? 
manifest.deletedFilesCount() : 0) + .set( + "partition_summaries", + Lists.transform( + manifest.partitions(), + partition -> + summaryBuilder + .set("contains_null", false) + .set("contains_nan", false) + .set("lower_bound", "1") + .set("upper_bound", "1") + .build())) + .set("reference_snapshot_id", referenceSnapshotId) + .build(); + } + + private PositionDeleteWriter<InternalRow> newPositionDeleteWriter( + Table table, PartitionSpec spec, StructLike partition) { + OutputFileFactory fileFactory = OutputFileFactory.builderFor(table, 0, 0).build(); + EncryptedOutputFile outputFile = fileFactory.newOutputFile(spec, partition); + + SparkFileWriterFactory fileWriterFactory = SparkFileWriterFactory.builderFor(table).build(); + return fileWriterFactory.newPositionDeleteWriter(outputFile, spec, partition); + } + + private DeleteFile writePositionDeletes( + Table table, + PartitionSpec spec, + StructLike partition, + Iterable<PositionDelete<InternalRow>> deletes) { + PositionDeleteWriter<InternalRow> positionDeleteWriter = + newPositionDeleteWriter(table, spec, partition); + + try (PositionDeleteWriter<InternalRow> writer = positionDeleteWriter) { + for (PositionDelete<InternalRow> delete : deletes) { + writer.write(delete); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + return positionDeleteWriter.toDeleteFile(); + } + + private DeleteFile writePosDeleteFile(Table table) { + return writePosDeleteFile(table, 0L); + } + + private DeleteFile writePosDeleteFile(Table table, long pos) { + DataFile dataFile = + Iterables.getFirst(table.currentSnapshot().addedDataFiles(table.io()), null); + PartitionSpec dataFileSpec = table.specs().get(dataFile.specId()); + StructLike dataFilePartition = dataFile.partition(); + + PositionDelete<InternalRow> delete = PositionDelete.create(); + delete.set(dataFile.path(), pos, null); + + return writePositionDeletes(table, dataFileSpec, dataFilePartition, ImmutableList.of(delete)); + } + + private DeleteFile writeEqDeleteFile(Table table, String dataValue) { + List<Record> deletes = Lists.newArrayList(); + Schema deleteRowSchema = SCHEMA.select("data"); + Record delete = GenericRecord.create(deleteRowSchema); + deletes.add(delete.copy("data", dataValue)); + try { + return FileHelpers.writeDeleteFile( + table, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + org.apache.iceberg.TestHelpers.Row.of(1), + deletes, + deleteRowSchema); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private long totalSizeInBytes(Iterable<DataFile> dataFiles) { + return Lists.newArrayList(dataFiles).stream().mapToLong(DataFile::fileSizeInBytes).sum(); + } + + private void assertDataFilePartitions( + List<DataFile> dataFiles, List<Integer> expectedPartitionIds) { + assertThat(dataFiles) + .as("Table should have " + expectedPartitionIds.size() + " data files") + .hasSameSizeAs(expectedPartitionIds); + + for (int i = 0; i < dataFiles.size(); ++i) { + assertThat(dataFiles.get(i).partition().get(0, Integer.class).intValue()) + .as("Data file should have partition of id " + expectedPartitionIds.get(i)) + .isEqualTo(expectedPartitionIds.get(i).intValue()); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java new file mode 100644 index 000000000000..7eff93d204e4 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; +import org.apache.iceberg.spark.IcebergSpark; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.types.CharType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.VarcharType; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +public class TestIcebergSpark { + + private static SparkSession spark = null; + + @BeforeAll + public static void startSpark() { + TestIcebergSpark.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestIcebergSpark.spark; + TestIcebergSpark.spark = null; + currentSpark.stop(); + } + + @Test + public void testRegisterIntegerBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_int_16", DataTypes.IntegerType, 16); + List results = spark.sql("SELECT iceberg_bucket_int_16(1)").collectAsList(); + + assertThat(results).hasSize(1); + assertThat(results.get(0).getInt(0)) + .isEqualTo(Transforms.bucket(16).bind(Types.IntegerType.get()).apply(1)); + } + + @Test + public void testRegisterShortBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_short_16", DataTypes.ShortType, 16); + List results = spark.sql("SELECT iceberg_bucket_short_16(1S)").collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getInt(0)) + .isEqualTo(Transforms.bucket(16).bind(Types.IntegerType.get()).apply(1)); + } + + @Test + public void testRegisterByteBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_byte_16", DataTypes.ByteType, 16); + List results = spark.sql("SELECT iceberg_bucket_byte_16(1Y)").collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getInt(0)) + .isEqualTo(Transforms.bucket(16).bind(Types.IntegerType.get()).apply(1)); + } + + @Test + public void testRegisterLongBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_long_16", DataTypes.LongType, 16); + List results = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getInt(0)) + .isEqualTo(Transforms.bucket(16).bind(Types.LongType.get()).apply(1L)); + } + + @Test + public void testRegisterStringBucketUDF() { + 
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_string_16", DataTypes.StringType, 16); + List results = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getInt(0)) + .isEqualTo(Transforms.bucket(16).bind(Types.StringType.get()).apply("hello")); + } + + @Test + public void testRegisterCharBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_char_16", new CharType(5), 16); + List results = spark.sql("SELECT iceberg_bucket_char_16('hello')").collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getInt(0)) + .isEqualTo(Transforms.bucket(16).bind(Types.StringType.get()).apply("hello")); + } + + @Test + public void testRegisterVarCharBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_varchar_16", new VarcharType(5), 16); + List results = spark.sql("SELECT iceberg_bucket_varchar_16('hello')").collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getInt(0)) + .isEqualTo(Transforms.bucket(16).bind(Types.StringType.get()).apply("hello")); + } + + @Test + public void testRegisterDateBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_date_16", DataTypes.DateType, 16); + List results = + spark.sql("SELECT iceberg_bucket_date_16(DATE '2021-06-30')").collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getInt(0)) + .isEqualTo( + Transforms.bucket(16) + .bind(Types.DateType.get()) + .apply(DateTimeUtils.fromJavaDate(Date.valueOf("2021-06-30")))); + } + + @Test + public void testRegisterTimestampBucketUDF() { + IcebergSpark.registerBucketUDF( + spark, "iceberg_bucket_timestamp_16", DataTypes.TimestampType, 16); + List results = + spark + .sql("SELECT iceberg_bucket_timestamp_16(TIMESTAMP '2021-06-30 00:00:00.000')") + .collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getInt(0)) + .isEqualTo( + Transforms.bucket(16) + .bind(Types.TimestampType.withZone()) + .apply( + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2021-06-30 00:00:00.000")))); + } + + @Test + public void testRegisterBinaryBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_binary_16", DataTypes.BinaryType, 16); + List results = spark.sql("SELECT iceberg_bucket_binary_16(X'0020001F')").collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getInt(0)) + .isEqualTo( + Transforms.bucket(16) + .bind(Types.BinaryType.get()) + .apply(ByteBuffer.wrap(new byte[] {0x00, 0x20, 0x00, 0x1F}))); + } + + @Test + public void testRegisterDecimalBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_decimal_16", new DecimalType(4, 2), 16); + List results = spark.sql("SELECT iceberg_bucket_decimal_16(11.11)").collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getInt(0)) + .isEqualTo( + Transforms.bucket(16).bind(Types.DecimalType.of(4, 2)).apply(new BigDecimal("11.11"))); + } + + @Test + public void testRegisterBooleanBucketUDF() { + assertThatThrownBy( + () -> + IcebergSpark.registerBucketUDF( + spark, "iceberg_bucket_boolean_16", DataTypes.BooleanType, 16)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot bucket by type: boolean"); + } + + @Test + public void testRegisterDoubleBucketUDF() { + assertThatThrownBy( + () -> + IcebergSpark.registerBucketUDF( + spark, "iceberg_bucket_double_16", DataTypes.DoubleType, 16)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot bucket by type: 
double"); + } + + @Test + public void testRegisterFloatBucketUDF() { + assertThatThrownBy( + () -> + IcebergSpark.registerBucketUDF( + spark, "iceberg_bucket_float_16", DataTypes.FloatType, 16)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot bucket by type: float"); + } + + @Test + public void testRegisterIntegerTruncateUDF() { + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_int_4", DataTypes.IntegerType, 4); + List results = spark.sql("SELECT iceberg_truncate_int_4(1)").collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getInt(0)) + .isEqualTo(Transforms.truncate(4).bind(Types.IntegerType.get()).apply(1)); + } + + @Test + public void testRegisterLongTruncateUDF() { + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_long_4", DataTypes.LongType, 4); + List results = spark.sql("SELECT iceberg_truncate_long_4(1L)").collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getLong(0)) + .isEqualTo(Transforms.truncate(4).bind(Types.LongType.get()).apply(1L)); + } + + @Test + public void testRegisterDecimalTruncateUDF() { + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_decimal_4", new DecimalType(4, 2), 4); + List results = spark.sql("SELECT iceberg_truncate_decimal_4(11.11)").collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getDecimal(0)) + .isEqualTo( + Transforms.truncate(4).bind(Types.DecimalType.of(4, 2)).apply(new BigDecimal("11.11"))); + } + + @Test + public void testRegisterStringTruncateUDF() { + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_string_4", DataTypes.StringType, 4); + List results = spark.sql("SELECT iceberg_truncate_string_4('hello')").collectAsList(); + assertThat(results).hasSize(1); + assertThat(results.get(0).getString(0)) + .isEqualTo(Transforms.truncate(4).bind(Types.StringType.get()).apply("hello")); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java new file mode 100644 index 000000000000..35a675029c1c --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestIdentityPartitionData extends TestBase { + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + @Parameters(name = "format = {0}, vectorized = {1}, properties = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + "parquet", + false, + ImmutableMap.of( + TableProperties.DEFAULT_FILE_FORMAT, "parquet", + TableProperties.DATA_PLANNING_MODE, LOCAL.modeName(), + TableProperties.DELETE_PLANNING_MODE, LOCAL.modeName()) + }, + { + "parquet", + true, + ImmutableMap.of( + TableProperties.DEFAULT_FILE_FORMAT, "parquet", + TableProperties.DATA_PLANNING_MODE, DISTRIBUTED.modeName(), + TableProperties.DELETE_PLANNING_MODE, DISTRIBUTED.modeName()) + }, + { + "avro", + false, + ImmutableMap.of( + TableProperties.DEFAULT_FILE_FORMAT, "avro", + TableProperties.DATA_PLANNING_MODE, LOCAL.modeName(), + TableProperties.DELETE_PLANNING_MODE, LOCAL.modeName()) + }, + { + "orc", + false, + ImmutableMap.of( + TableProperties.DEFAULT_FILE_FORMAT, "orc", + TableProperties.DATA_PLANNING_MODE, DISTRIBUTED.modeName(), + TableProperties.DELETE_PLANNING_MODE, DISTRIBUTED.modeName()) + }, + { + "orc", + true, + ImmutableMap.of( + TableProperties.DEFAULT_FILE_FORMAT, "orc", + TableProperties.DATA_PLANNING_MODE, LOCAL.modeName(), + TableProperties.DELETE_PLANNING_MODE, LOCAL.modeName()) + }, + }; + } + + @Parameter(index = 0) + private String format; + + @Parameter(index = 1) + private boolean vectorized; + + @Parameter(index = 2) + private Map properties; + + private static final Schema LOG_SCHEMA = + new Schema( + Types.NestedField.optional(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "date", Types.StringType.get()), + Types.NestedField.optional(3, "level", Types.StringType.get()), + Types.NestedField.optional(4, "message", Types.StringType.get())); + + private static final List LOGS = + ImmutableList.of( + LogMessage.debug("2020-02-02", "debug event 1"), + LogMessage.info("2020-02-02", "info event 1"), 
+ LogMessage.debug("2020-02-02", "debug event 2"), + LogMessage.info("2020-02-03", "info event 2"), + LogMessage.debug("2020-02-03", "debug event 3"), + LogMessage.info("2020-02-03", "info event 3"), + LogMessage.error("2020-02-03", "error event 1"), + LogMessage.debug("2020-02-04", "debug event 4"), + LogMessage.warn("2020-02-04", "warn event 1"), + LogMessage.debug("2020-02-04", "debug event 5")); + + @TempDir private Path temp; + + private final PartitionSpec spec = + PartitionSpec.builderFor(LOG_SCHEMA).identity("date").identity("level").build(); + private Table table = null; + private Dataset logs = null; + + /** + * Use the Hive Based table to make Identity Partition Columns with no duplication of the data in + * the underlying parquet files. This makes sure that if the identity mapping fails, the test will + * also fail. + */ + private void setupParquet() throws Exception { + File location = Files.createTempDirectory(temp, "logs").toFile(); + File hiveLocation = Files.createTempDirectory(temp, "hive").toFile(); + String hiveTable = "hivetable"; + assertThat(location).as("Temp folder should exist").exists(); + + this.logs = + spark.createDataFrame(LOGS, LogMessage.class).select("id", "date", "level", "message"); + spark.sql(String.format("DROP TABLE IF EXISTS %s", hiveTable)); + logs.orderBy("date", "level", "id") + .write() + .partitionBy("date", "level") + .format("parquet") + .option("path", hiveLocation.toString()) + .saveAsTable(hiveTable); + + this.table = + TABLES.create( + SparkSchemaUtil.schemaForTable(spark, hiveTable), + SparkSchemaUtil.specForTable(spark, hiveTable), + properties, + location.toString()); + + SparkTableUtil.importSparkTable( + spark, new TableIdentifier(hiveTable), table, location.toString()); + } + + @BeforeEach + public void setupTable() throws Exception { + if (format.equals("parquet")) { + setupParquet(); + } else { + File location = Files.createTempDirectory(temp, "logs").toFile(); + assertThat(location).as("Temp folder should exist").exists(); + + this.table = TABLES.create(LOG_SCHEMA, spec, properties, location.toString()); + this.logs = + spark.createDataFrame(LOGS, LogMessage.class).select("id", "date", "level", "message"); + + logs.orderBy("date", "level", "id") + .write() + .format("iceberg") + .mode("append") + .save(location.toString()); + } + } + + @TestTemplate + public void testFullProjection() { + List expected = logs.orderBy("id").collectAsList(); + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(table.location()) + .orderBy("id") + .select("id", "date", "level", "message") + .collectAsList(); + assertThat(actual).as("Rows should match").isEqualTo(expected); + } + + @TestTemplate + public void testProjections() { + String[][] cases = + new String[][] { + // individual fields + new String[] {"date"}, + new String[] {"level"}, + new String[] {"message"}, + // field pairs + new String[] {"date", "message"}, + new String[] {"level", "message"}, + new String[] {"date", "level"}, + // out-of-order pairs + new String[] {"message", "date"}, + new String[] {"message", "level"}, + new String[] {"level", "date"}, + // full projection, different orderings + new String[] {"date", "level", "message"}, + new String[] {"level", "date", "message"}, + new String[] {"date", "message", "level"}, + new String[] {"level", "message", "date"}, + new String[] {"message", "date", "level"}, + new String[] {"message", "level", "date"} + }; + + for (String[] ordering : cases) { 
+ List expected = logs.select("id", ordering).orderBy("id").collectAsList(); + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(table.location()) + .select("id", ordering) + .orderBy("id") + .collectAsList(); + assertThat(actual) + .as("Rows should match for ordering: " + Arrays.toString(ordering)) + .isEqualTo(expected); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestInternalRowWrapper.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestInternalRowWrapper.java new file mode 100644 index 000000000000..0c869aa8e7e0 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestInternalRowWrapper.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Iterator; +import org.apache.iceberg.RecordWrapperTest; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.data.InternalRecordWrapper; +import org.apache.iceberg.data.RandomGenericData; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.util.StructLikeWrapper; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.jupiter.api.Disabled; + +public class TestInternalRowWrapper extends RecordWrapperTest { + + @Disabled + @Override + public void testTimestampWithoutZone() { + // Spark does not support timestamp without zone. + } + + @Disabled + @Override + public void testTime() { + // Spark does not support time fields. 
+ } + + @Override + protected void generateAndValidate(Schema schema, AssertMethod assertMethod) { + int numRecords = 100; + Iterable<Record> recordList = RandomGenericData.generate(schema, numRecords, 101L); + Iterable<InternalRow> rowList = RandomData.generateSpark(schema, numRecords, 101L); + + InternalRecordWrapper recordWrapper = new InternalRecordWrapper(schema.asStruct()); + InternalRowWrapper rowWrapper = + new InternalRowWrapper(SparkSchemaUtil.convert(schema), schema.asStruct()); + + Iterator<Record> actual = recordList.iterator(); + Iterator<InternalRow> expected = rowList.iterator(); + + StructLikeWrapper actualWrapper = StructLikeWrapper.forType(schema.asStruct()); + StructLikeWrapper expectedWrapper = StructLikeWrapper.forType(schema.asStruct()); + for (int i = 0; i < numRecords; i++) { + assertThat(actual).as("Should have more records").hasNext(); + assertThat(expected).as("Should have more InternalRow").hasNext(); + + StructLike recordStructLike = recordWrapper.wrap(actual.next()); + StructLike rowStructLike = rowWrapper.wrap(expected.next()); + + assertMethod.assertEquals( + "Should have expected StructLike values", + actualWrapper.set(recordStructLike), + expectedWrapper.set(rowStructLike)); + } + + assertThat(actual).as("Shouldn't have more record").isExhausted(); + assertThat(expected).as("Shouldn't have more InternalRow").isExhausted(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTableReadableMetrics.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTableReadableMetrics.java new file mode 100644 index 000000000000..547ab32eac24 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTableReadableMetrics.java @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.file.Path; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Files; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.io.TempDir; + +public class TestMetadataTableReadableMetrics extends TestBaseWithCatalog { + + @TempDir private Path temp; + + private static final Types.StructType LEAF_STRUCT_TYPE = + Types.StructType.of( + optional(1, "leafLongCol", Types.LongType.get()), + optional(2, "leafDoubleCol", Types.DoubleType.get())); + + private static final Types.StructType NESTED_STRUCT_TYPE = + Types.StructType.of(required(3, "leafStructCol", LEAF_STRUCT_TYPE)); + + private static final Schema NESTED_SCHEMA = + new Schema(required(4, "nestedStructCol", NESTED_STRUCT_TYPE)); + + private static final Schema PRIMITIVE_SCHEMA = + new Schema( + required(1, "booleanCol", Types.BooleanType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "longCol", Types.LongType.get()), + required(4, "floatCol", Types.FloatType.get()), + required(5, "doubleCol", Types.DoubleType.get()), + optional(6, "decimalCol", Types.DecimalType.of(10, 2)), + optional(7, "stringCol", Types.StringType.get()), + optional(8, "fixedCol", Types.FixedType.ofLength(3)), + optional(9, "binaryCol", Types.BinaryType.get())); + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + protected static Object[][] parameters() { + return new Object[][] { + { + // only SparkCatalog supports metadata table sql queries + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties() + }, + }; + } + + protected String tableName() { + return tableName.split("\\.")[2]; + } + + protected String database() { + return tableName.split("\\.")[1]; + } + + private Table createPrimitiveTable() throws IOException { + Table table = + catalog.createTable( + TableIdentifier.of(Namespace.of(database()), tableName()), + PRIMITIVE_SCHEMA, + PartitionSpec.unpartitioned(), + ImmutableMap.of()); + List records = + Lists.newArrayList( + createPrimitiveRecord( + false, + 1, + 1L, + 0, + 1.0D, + new BigDecimal("1.00"), + "1", + Base64.getDecoder().decode("1111"), + ByteBuffer.wrap(Base64.getDecoder().decode("1111"))), + createPrimitiveRecord( + true, + 2, + 2L, + 0, + 2.0D, + new 
BigDecimal("2.00"), + "2", + Base64.getDecoder().decode("2222"), + ByteBuffer.wrap(Base64.getDecoder().decode("2222"))), + createPrimitiveRecord(false, 1, 1, Float.NaN, Double.NaN, null, "1", null, null), + createPrimitiveRecord( + false, 2, 2L, Float.NaN, 2.0D, new BigDecimal("2.00"), "2", null, null)); + + DataFile dataFile = FileHelpers.writeDataFile(table, Files.localOutput(temp.toFile()), records); + table.newAppend().appendFile(dataFile).commit(); + return table; + } + + private Pair createNestedTable() throws IOException { + Table table = + catalog.createTable( + TableIdentifier.of(Namespace.of(database()), tableName()), + NESTED_SCHEMA, + PartitionSpec.unpartitioned(), + ImmutableMap.of()); + + List records = + Lists.newArrayList( + createNestedRecord(0L, 0.0), + createNestedRecord(1L, Double.NaN), + createNestedRecord(null, null)); + DataFile dataFile = FileHelpers.writeDataFile(table, Files.localOutput(temp.toFile()), records); + table.newAppend().appendFile(dataFile).commit(); + return Pair.of(table, dataFile); + } + + @AfterEach + public void dropTable() { + sql("DROP TABLE %s", tableName); + } + + private Dataset filesDf() { + return spark.read().format("iceberg").load(database() + "." + tableName() + ".files"); + } + + protected GenericRecord createPrimitiveRecord( + boolean booleanCol, + int intCol, + long longCol, + float floatCol, + double doubleCol, + BigDecimal decimalCol, + String stringCol, + byte[] fixedCol, + ByteBuffer binaryCol) { + GenericRecord record = GenericRecord.create(PRIMITIVE_SCHEMA); + record.set(0, booleanCol); + record.set(1, intCol); + record.set(2, longCol); + record.set(3, floatCol); + record.set(4, doubleCol); + record.set(5, decimalCol); + record.set(6, stringCol); + record.set(7, fixedCol); + record.set(8, binaryCol); + return record; + } + + private GenericRecord createNestedRecord(Long longCol, Double doubleCol) { + GenericRecord record = GenericRecord.create(NESTED_SCHEMA); + GenericRecord nested = GenericRecord.create(NESTED_STRUCT_TYPE); + GenericRecord leaf = GenericRecord.create(LEAF_STRUCT_TYPE); + leaf.set(0, longCol); + leaf.set(1, doubleCol); + nested.set(0, leaf); + record.set(0, nested); + return record; + } + + @TestTemplate + public void testPrimitiveColumns() throws Exception { + Table table = createPrimitiveTable(); + DataFile dataFile = table.currentSnapshot().addedDataFiles(table.io()).iterator().next(); + Map columnSizeStats = dataFile.columnSizes(); + + Object[] binaryCol = + row( + columnSizeStats.get(PRIMITIVE_SCHEMA.findField("binaryCol").fieldId()), + 4L, + 2L, + null, + Base64.getDecoder().decode("1111"), + Base64.getDecoder().decode("2222")); + Object[] booleanCol = + row( + columnSizeStats.get(PRIMITIVE_SCHEMA.findField("booleanCol").fieldId()), + 4L, + 0L, + null, + false, + true); + Object[] decimalCol = + row( + columnSizeStats.get(PRIMITIVE_SCHEMA.findField("decimalCol").fieldId()), + 4L, + 1L, + null, + new BigDecimal("1.00"), + new BigDecimal("2.00")); + Object[] doubleCol = + row( + columnSizeStats.get(PRIMITIVE_SCHEMA.findField("doubleCol").fieldId()), + 4L, + 0L, + 1L, + 1.0D, + 2.0D); + Object[] fixedCol = + row( + columnSizeStats.get(PRIMITIVE_SCHEMA.findField("fixedCol").fieldId()), + 4L, + 2L, + null, + Base64.getDecoder().decode("1111"), + Base64.getDecoder().decode("2222")); + Object[] floatCol = + row( + columnSizeStats.get(PRIMITIVE_SCHEMA.findField("floatCol").fieldId()), + 4L, + 0L, + 2L, + 0f, + 0f); + Object[] intCol = + row( + columnSizeStats.get(PRIMITIVE_SCHEMA.findField("intCol").fieldId()), + 
4L, + 0L, + null, + 1, + 2); + Object[] longCol = + row( + columnSizeStats.get(PRIMITIVE_SCHEMA.findField("longCol").fieldId()), + 4L, + 0L, + null, + 1L, + 2L); + Object[] stringCol = + row( + columnSizeStats.get(PRIMITIVE_SCHEMA.findField("stringCol").fieldId()), + 4L, + 0L, + null, + "1", + "2"); + + Object[] metrics = + row( + binaryCol, + booleanCol, + decimalCol, + doubleCol, + fixedCol, + floatCol, + intCol, + longCol, + stringCol); + + List expected = ImmutableList.of(new Object[] {metrics}); + String sql = "SELECT readable_metrics FROM %s.%s"; + List filesReadableMetrics = sql(String.format(sql, tableName, "files")); + List entriesReadableMetrics = sql(String.format(sql, tableName, "entries")); + assertEquals("Row should match for files table", expected, filesReadableMetrics); + assertEquals("Row should match for entries table", expected, entriesReadableMetrics); + } + + @TestTemplate + public void testSelectPrimitiveValues() throws Exception { + createPrimitiveTable(); + + List expected = ImmutableList.of(row(1, true)); + String sql = + "SELECT readable_metrics.intCol.lower_bound, readable_metrics.booleanCol.upper_bound FROM %s.%s"; + List filesReadableMetrics = sql(String.format(sql, tableName, "files")); + List entriesReadableMetrics = sql(String.format(sql, tableName, "entries")); + assertEquals( + "select of primitive readable_metrics fields should work for files table", + expected, + filesReadableMetrics); + assertEquals( + "select of primitive readable_metrics fields should work for entries table", + expected, + entriesReadableMetrics); + + assertEquals( + "mixed select of readable_metrics and other field should work", + ImmutableList.of(row(0, 4L)), + sql("SELECT content, readable_metrics.longCol.value_count FROM %s.files", tableName)); + + assertEquals( + "mixed select of readable_metrics and other field should work, in the other order", + ImmutableList.of(row(4L, 0)), + sql("SELECT readable_metrics.longCol.value_count, content FROM %s.files", tableName)); + + assertEquals( + "mixed select of readable_metrics and other field should work for entries table", + ImmutableList.of(row(1, 4L)), + sql("SELECT status, readable_metrics.longCol.value_count FROM %s.entries", tableName)); + + assertEquals( + "mixed select of readable_metrics and other field should work, in the other order for entries table", + ImmutableList.of(row(4L, 1)), + sql("SELECT readable_metrics.longCol.value_count, status FROM %s.entries", tableName)); + } + + @TestTemplate + public void testSelectNestedValues() throws Exception { + createNestedTable(); + + List expected = ImmutableList.of(row(0L, 3L)); + String sql = + "SELECT readable_metrics.`nestedStructCol.leafStructCol.leafLongCol`.lower_bound, " + + "readable_metrics.`nestedStructCol.leafStructCol.leafDoubleCol`.value_count FROM %s.%s"; + List filesReadableMetrics = sql(String.format(sql, tableName, "files")); + List entriesReadableMetrics = sql(String.format(sql, tableName, "entries")); + + assertEquals( + "select of nested readable_metrics fields should work for files table", + expected, + filesReadableMetrics); + assertEquals( + "select of nested readable_metrics fields should work for entries table", + expected, + entriesReadableMetrics); + } + + @TestTemplate + public void testNestedValues() throws Exception { + Pair table = createNestedTable(); + int longColId = + table.first().schema().findField("nestedStructCol.leafStructCol.leafLongCol").fieldId(); + int doubleColId = + 
table.first().schema().findField("nestedStructCol.leafStructCol.leafDoubleCol").fieldId(); + + Object[] leafDoubleCol = + row(table.second().columnSizes().get(doubleColId), 3L, 1L, 1L, 0.0D, 0.0D); + Object[] leafLongCol = row(table.second().columnSizes().get(longColId), 3L, 1L, null, 0L, 1L); + Object[] metrics = row(leafDoubleCol, leafLongCol); + + List expected = ImmutableList.of(new Object[] {metrics}); + String sql = "SELECT readable_metrics FROM %s.%s"; + List filesReadableMetrics = sql(String.format(sql, tableName, "files")); + List entriesReadableMetrics = sql(String.format(sql, tableName, "entries")); + assertEquals("Row should match for files table", expected, filesReadableMetrics); + assertEquals("Row should match for entries table", expected, entriesReadableMetrics); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTablesWithPartitionEvolution.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTablesWithPartitionEvolution.java new file mode 100644 index 000000000000..a417454b45dc --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTablesWithPartitionEvolution.java @@ -0,0 +1,725 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.FileFormat.AVRO; +import static org.apache.iceberg.FileFormat.ORC; +import static org.apache.iceberg.FileFormat.PARQUET; +import static org.apache.iceberg.MetadataTableType.ALL_DATA_FILES; +import static org.apache.iceberg.MetadataTableType.ALL_ENTRIES; +import static org.apache.iceberg.MetadataTableType.ENTRIES; +import static org.apache.iceberg.MetadataTableType.FILES; +import static org.apache.iceberg.MetadataTableType.PARTITIONS; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; +import static org.apache.iceberg.TableProperties.MANIFEST_MERGE_ENABLED; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.util.Arrays; +import java.util.List; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestMetadataTablesWithPartitionEvolution extends CatalogTestBase { + + @Parameters(name = "catalog = {0}, impl = {1}, conf = {2}, fileFormat = {3}, formatVersion = {4}") + public static Object[][] parameters() { + return new Object[][] { + { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default"), + ORC, + 1 + }, + { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default"), + ORC, + 2 + }, + {"testhadoop", SparkCatalog.class.getName(), ImmutableMap.of("type", "hadoop"), PARQUET, 1}, + {"testhadoop", SparkCatalog.class.getName(), ImmutableMap.of("type", "hadoop"), PARQUET, 2}, + { + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "clients", "1", + "parquet-enabled", "false", + "cache-enabled", + "false" // Spark will delete tables using v1, leaving the cache out of sync + ), + AVRO, + 1 + }, + { + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "clients", "1", + "parquet-enabled", "false", + "cache-enabled", + "false" // Spark will delete tables using v1, leaving the cache out of 
sync + ), + AVRO, + 2 + } + }; + } + + @Parameter(index = 3) + private FileFormat fileFormat; + + @Parameter(index = 4) + private int formatVersion; + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testFilesMetadataTable() throws ParseException { + createTable("id bigint NOT NULL, category string, data string"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables while the current spec is still unpartitioned + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + Dataset df = loadMetadataTable(tableType); + assertThat(df.schema().getFieldIndex("partition").isEmpty()) + .as("Partition must be skipped") + .isTrue(); + } + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the first partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(new Object[] {null}), row("b1")), "STRUCT", tableType); + } + + table.updateSpec().addField(Expressions.bucket("category", 8)).commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the second partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, null), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after dropping the first partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec().renameField("category_bucket_8", "category_bucket_8_another_name").commit(); + sql("REFRESH TABLE %s", tableName); + + // verify the metadata tables after renaming the second partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + } + + @TestTemplate + public void testFilesMetadataTableFilter() throws ParseException { + createTable("id bigint NOT NULL, category string, data string"); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'false')", tableName, MANIFEST_MERGE_ENABLED); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables while the current spec is still unpartitioned + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + Dataset df = loadMetadataTable(tableType); + assertThat(df.schema().getFieldIndex("partition").isEmpty()) + .as("Partition must be skipped") + .isTrue(); + } + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the 
metadata tables after adding the first partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row("d2")), "STRUCT", tableType, "partition.data = 'd2'"); + } + + table.updateSpec().addField("category").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables after adding the second partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row("d2", null), row("d2", "c2")), + "STRUCT", + tableType, + "partition.data = 'd2'"); + } + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row("d2", "c2")), + "STRUCT", + tableType, + "partition.category = 'c2'"); + } + + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + + // Verify new partitions do not show up for removed 'partition.data=d2' query + sql("INSERT INTO TABLE %s VALUES (3, 'c3', 'd2')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'c4', 'd2')", tableName); + + // Verify new partitions do show up for 'partition.category=c2' query + sql("INSERT INTO TABLE %s VALUES (5, 'c2', 'd5')", tableName); + + // no new partition should show up for 'data' partition query as partition field has been + // removed + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row("d2", null), row("d2", "c2")), + "STRUCT", + tableType, + "partition.data = 'd2'"); + } + // new partition shows up from 'category' partition field query + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, "c2"), row("d2", "c2")), + "STRUCT", + tableType, + "partition.category = 'c2'"); + } + + table.updateSpec().renameField("category", "category_another_name").commit(); + sql("REFRESH TABLE %s", tableName); + + // Verify new partitions do show up for 'category=c2' query + sql("INSERT INTO TABLE %s VALUES (6, 'c2', 'd6')", tableName); + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, "c2"), row(null, "c2"), row("d2", "c2")), + "STRUCT", + tableType, + "partition.category_another_name = 'c2'"); + } + } + + @TestTemplate + public void testEntriesMetadataTable() throws ParseException { + createTable("id bigint NOT NULL, category string, data string"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables while the current spec is still unpartitioned + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + Dataset df = loadMetadataTable(tableType); + StructType dataFileType = (StructType) df.schema().apply("data_file").dataType(); + assertThat(dataFileType.getFieldIndex("").isEmpty()).as("Partition must be skipped").isTrue(); + } + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the first partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(new Object[] {null}), row("b1")), "STRUCT", tableType); + } + + 
table.updateSpec().addField(Expressions.bucket("category", 8)).commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the second partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(null, null), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after dropping the first partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec().renameField("category_bucket_8", "category_bucket_8_another_name").commit(); + sql("REFRESH TABLE %s", tableName); + + // verify the metadata tables after renaming the second partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + } + + @TestTemplate + public void testPartitionsTableAddRemoveFields() throws ParseException { + createTable("id bigint NOT NULL, category string, data string"); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables while the current spec is still unpartitioned + Dataset df = loadMetadataTable(PARTITIONS); + assertThat(df.schema().getFieldIndex("partition").isEmpty()) + .as("Partition must be skipped") + .isTrue(); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables after adding the first partition column + assertPartitions( + ImmutableList.of(row(new Object[] {null}), row("d1"), row("d2")), + "STRUCT", + PARTITIONS); + + table.updateSpec().addField("category").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables after adding the second partition column + assertPartitions( + ImmutableList.of( + row(null, null), row("d1", null), row("d1", "c1"), row("d2", null), row("d2", "c2")), + "STRUCT", + PARTITIONS); + + // verify the metadata tables after removing the first partition column + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of( + row(null, null), + row(null, "c1"), + row(null, "c2"), + row("d1", null), + row("d1", "c1"), + row("d2", null), + row("d2", "c2")), + "STRUCT", + PARTITIONS); + } + + @TestTemplate + public void testPartitionsTableRenameFields() throws ParseException { + createTable("id bigint NOT NULL, category string, data string"); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec().addField("data").addField("category").commit(); + 
sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d1", "c1"), row("d2", "c2")), + "STRUCT", + PARTITIONS); + + table.updateSpec().renameField("category", "category_another_name").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d1", "c1"), row("d2", "c2")), + "STRUCT", + PARTITIONS); + } + + @TestTemplate + public void testPartitionsTableSwitchFields() throws Exception { + createTable("id bigint NOT NULL, category string, data string"); + + Table table = validationCatalog.loadTable(tableIdent); + + // verify the metadata tables after re-adding the first dropped column in the second location + table.updateSpec().addField("data").addField("category").commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d1", "c1"), row("d2", "c2")), + "STRUCT", + PARTITIONS); + + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row(null, "c1"), row(null, "c2"), row("d1", "c1"), row("d2", "c2")), + "STRUCT", + PARTITIONS); + + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + sql("INSERT INTO TABLE %s VALUES (3, 'c3', 'd3')", tableName); + + if (formatVersion == 1) { + assertPartitions( + ImmutableList.of( + row(null, "c1", null), + row(null, "c1", "d1"), + row(null, "c2", null), + row(null, "c2", "d2"), + row(null, "c3", "d3"), + row("d1", "c1", null), + row("d2", "c2", null)), + "STRUCT", + PARTITIONS); + } else { + // In V2 re-adding a former partition field that was part of an older spec will not change its + // name or its + // field ID either, thus values will be collapsed into a single common column (as opposed to + // V1 where any new + // partition field addition will result in a new column in this metadata table) + assertPartitions( + ImmutableList.of( + row(null, "c1"), row(null, "c2"), row("d1", "c1"), row("d2", "c2"), row("d3", "c3")), + "STRUCT", + PARTITIONS); + } + } + + @TestTemplate + public void testPartitionTableFilterAddRemoveFields() throws ParseException { + // Create un-partitioned table + createTable("id bigint NOT NULL, category string, data string"); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // Partition Table with one partition column + Table table = validationCatalog.loadTable(tableIdent); + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d2")), "STRUCT", PARTITIONS, "partition.data = 'd2'"); + + // Partition Table with two partition column + table.updateSpec().addField("category").commit(); + sql("REFRESH 
TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d2", null), row("d2", "c2")), + "STRUCT", + PARTITIONS, + "partition.data = 'd2'"); + assertPartitions( + ImmutableList.of(row("d2", "c2")), + "STRUCT", + PARTITIONS, + "partition.category = 'c2'"); + + // Partition Table with first partition column removed + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (3, 'c3', 'd2')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'c4', 'd2')", tableName); + sql("INSERT INTO TABLE %s VALUES (5, 'c2', 'd5')", tableName); + assertPartitions( + ImmutableList.of(row("d2", null), row("d2", "c2")), + "STRUCT", + PARTITIONS, + "partition.data = 'd2'"); + assertPartitions( + ImmutableList.of(row(null, "c2"), row("d2", "c2")), + "STRUCT", + PARTITIONS, + "partition.category = 'c2'"); + } + + @TestTemplate + public void testPartitionTableFilterSwitchFields() throws Exception { + // Re-added partition fields currently not re-associated: + // https://github.com/apache/iceberg/issues/4292 + // In V1, dropped partition fields show separately when field is re-added + // In V2, re-added field currently conflicts with its deleted form + assumeThat(formatVersion).isEqualTo(1); + + createTable("id bigint NOT NULL, category string, data string"); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + + // Two partition columns + table.updateSpec().addField("data").addField("category").commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // Drop first partition column + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // Re-add first partition column at the end + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row(null, "c2", null), row(null, "c2", "d2"), row("d2", "c2", null)), + "STRUCT", + PARTITIONS, + "partition.category = 'c2'"); + + assertPartitions( + ImmutableList.of(row(null, "c1", "d1")), + "STRUCT", + PARTITIONS, + "partition.data = 'd1'"); + } + + @TestTemplate + public void testPartitionsTableFilterRenameFields() throws ParseException { + createTable("id bigint NOT NULL, category string, data string"); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec().addField("data").addField("category").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + table.updateSpec().renameField("category", "category_another_name").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d1", "c1")), + "STRUCT", + PARTITIONS, + 
"partition.category_another_name = 'c1'"); + } + + @TestTemplate + public void testMetadataTablesWithUnknownTransforms() { + createTable("id bigint NOT NULL, category string, data string"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + PartitionSpec unknownSpec = + TestHelpers.newExpectedSpecBuilder() + .withSchema(table.schema()) + .withSpecId(1) + .addField("zero", 1, "id_zero") + .build(); + + // replace the table spec to include an unknown transform + TableOperations ops = ((HasTableOperations) table).operations(); + TableMetadata base = ops.current(); + ops.commit(base, base.updatePartitionSpec(unknownSpec)); + + sql("REFRESH TABLE %s", tableName); + + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES, ENTRIES, ALL_ENTRIES)) { + assertThatThrownBy(() -> loadMetadataTable(tableType)) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot build table partition type, unknown transforms: [zero]"); + } + } + + @TestTemplate + public void testPartitionColumnNamedPartition() { + sql( + "CREATE TABLE %s (id int, partition int) USING iceberg PARTITIONED BY (partition)", + tableName); + sql("INSERT INTO %s VALUES (1, 1), (2, 1), (3, 2), (2, 2)", tableName); + List expected = ImmutableList.of(row(1, 1), row(2, 1), row(3, 2), row(2, 2)); + assertEquals("Should return all expected rows", expected, sql("SELECT * FROM %s", tableName)); + assertThat(sql("SELECT * FROM %s.files", tableName)).hasSize(2); + } + + private void assertPartitions( + List expectedPartitions, String expectedTypeAsString, MetadataTableType tableType) + throws ParseException { + assertPartitions(expectedPartitions, expectedTypeAsString, tableType, null); + } + + private void assertPartitions( + List expectedPartitions, + String expectedTypeAsString, + MetadataTableType tableType, + String filter) + throws ParseException { + Dataset df = loadMetadataTable(tableType); + if (filter != null) { + df = df.filter(filter); + } + + DataType expectedType = spark.sessionState().sqlParser().parseDataType(expectedTypeAsString); + switch (tableType) { + case PARTITIONS: + case FILES: + case ALL_DATA_FILES: + DataType actualFilesType = df.schema().apply("partition").dataType(); + assertThat(actualFilesType).as("Partition type must match").isEqualTo(expectedType); + break; + + case ENTRIES: + case ALL_ENTRIES: + StructType dataFileType = (StructType) df.schema().apply("data_file").dataType(); + DataType actualEntriesType = dataFileType.apply("partition").dataType(); + assertThat(actualEntriesType).as("Partition type must match").isEqualTo(expectedType); + break; + + default: + throw new UnsupportedOperationException("Unsupported metadata table type: " + tableType); + } + + switch (tableType) { + case PARTITIONS: + case FILES: + case ALL_DATA_FILES: + List actualFilesPartitions = + df.orderBy("partition").select("partition.*").collectAsList(); + assertEquals( + "Partitions must match", expectedPartitions, rowsToJava(actualFilesPartitions)); + break; + + case ENTRIES: + case ALL_ENTRIES: + List actualEntriesPartitions = + df.orderBy("data_file.partition").select("data_file.partition.*").collectAsList(); + assertEquals( + "Partitions must match", expectedPartitions, rowsToJava(actualEntriesPartitions)); + break; + + default: + throw new UnsupportedOperationException("Unsupported metadata table type: " + tableType); + } + } + + private Dataset loadMetadataTable(MetadataTableType tableType) { + return 
spark.read().format("iceberg").load(tableName + "." + tableType.name()); + } + + private void createTable(String schema) { + sql( + "CREATE TABLE %s (%s) USING iceberg TBLPROPERTIES ('%s' '%s', '%s' '%d')", + tableName, schema, DEFAULT_FILE_FORMAT, fileFormat.name(), FORMAT_VERSION, formatVersion); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java new file mode 100644 index 000000000000..ebeed62acce4 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.Files.localOutput; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.apache.spark.sql.functions.monotonically_increasing_id; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.UUID; +import org.apache.avro.generic.GenericData; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.spark.data.ParameterizedAvroDataTest; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestParquetScan extends ParameterizedAvroDataTest { + private static final Configuration CONF = new Configuration(); + + private static SparkSession spark = null; + + @BeforeAll + 
public static void startSpark() { + TestParquetScan.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestParquetScan.spark; + TestParquetScan.spark = null; + currentSpark.stop(); + } + + @TempDir private Path temp; + + @Parameter private boolean vectorized; + + @Parameters(name = "vectorized = {0}") + public static Collection parameters() { + return Arrays.asList(false, true); + } + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + assumeThat( + TypeUtil.find( + schema, + type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get())) + .as("Cannot handle non-string map keys in parquet-avro") + .isNull(); + + assertThat(vectorized).as("should not be null").isNotNull(); + Table table = createTable(schema); + + // Important: use the table's schema for the rest of the test + // When tables are created, the column ids are reassigned. + List expected = RandomData.generateList(table.schema(), 100, 1L); + writeRecords(table, expected); + + configureVectorization(table); + + Dataset df = spark.read().format("iceberg").load(table.location()); + + List rows = df.collectAsList(); + assertThat(rows).as("Should contain 100 rows").hasSize(100); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(table.schema().asStruct(), expected.get(i), rows.get(i)); + } + } + + @TestTemplate + public void testEmptyTableProjection() throws IOException { + Types.StructType structType = + Types.StructType.of( + required(100, "id", Types.LongType.get()), + optional(101, "data", Types.StringType.get()), + required(102, "b", Types.BooleanType.get()), + optional(103, "i", Types.IntegerType.get())); + Table table = createTable(new Schema(structType.fields())); + + List expected = RandomData.generateList(table.schema(), 100, 1L); + writeRecords(table, expected); + + configureVectorization(table); + + List rows = + spark + .read() + .format("iceberg") + .load(table.location()) + .select(monotonically_increasing_id()) + .collectAsList(); + assertThat(rows).hasSize(100); + } + + private Table createTable(Schema schema) throws IOException { + File parent = temp.resolve("parquet").toFile(); + File location = new File(parent, "test"); + HadoopTables tables = new HadoopTables(CONF); + return tables.create(schema, PartitionSpec.unpartitioned(), location.toString()); + } + + private void writeRecords(Table table, List records) throws IOException { + File dataFolder = new File(table.location(), "data"); + dataFolder.mkdirs(); + + File parquetFile = + new File(dataFolder, FileFormat.PARQUET.addExtension(UUID.randomUUID().toString())); + + try (FileAppender writer = + Parquet.write(localOutput(parquetFile)).schema(table.schema()).build()) { + writer.addAll(records); + } + + DataFile file = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withFileSizeInBytes(parquetFile.length()) + .withPath(parquetFile.toString()) + .withRecordCount(100) + .build(); + + table.newAppend().appendFile(file).commit(); + } + + private void configureVectorization(Table table) { + table + .updateProperties() + .set(TableProperties.PARQUET_VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .commit(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionPruning.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionPruning.java new file mode 100644 index 000000000000..9464f687b0eb --- /dev/null +++ 
b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionPruning.java @@ -0,0 +1,478 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.nio.file.Files; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.RawLocalFileSystem; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PlanningMode; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Types; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestPartitionPruning { + + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + @Parameters(name = 
"format = {0}, vectorized = {1}, planningMode = {2}") + public static Object[][] parameters() { + return new Object[][] { + {"parquet", false, DISTRIBUTED}, + {"parquet", true, LOCAL}, + {"avro", false, DISTRIBUTED}, + {"orc", false, LOCAL}, + {"orc", true, DISTRIBUTED} + }; + } + + @Parameter(index = 0) + private String format; + + @Parameter(index = 1) + private boolean vectorized; + + @Parameter(index = 2) + private PlanningMode planningMode; + + private static SparkSession spark = null; + private static JavaSparkContext sparkContext = null; + + private static final Function BUCKET_FUNC = + Transforms.bucket(3).bind(Types.IntegerType.get()); + private static final Function TRUNCATE_FUNC = + Transforms.truncate(5).bind(Types.StringType.get()); + private static final Function HOUR_FUNC = + Transforms.hour().bind(Types.TimestampType.withoutZone()); + + @BeforeAll + public static void startSpark() { + TestPartitionPruning.spark = SparkSession.builder().master("local[2]").getOrCreate(); + TestPartitionPruning.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + + String optionKey = String.format("fs.%s.impl", CountOpenLocalFileSystem.scheme); + CONF.set(optionKey, CountOpenLocalFileSystem.class.getName()); + spark.conf().set(optionKey, CountOpenLocalFileSystem.class.getName()); + spark.conf().set("spark.sql.session.timeZone", "UTC"); + spark.udf().register("bucket3", (Integer num) -> BUCKET_FUNC.apply(num), DataTypes.IntegerType); + spark + .udf() + .register("truncate5", (String str) -> TRUNCATE_FUNC.apply(str), DataTypes.StringType); + // NOTE: date transforms take the type long, not Timestamp + spark + .udf() + .register( + "hour", + (Timestamp ts) -> + HOUR_FUNC.apply( + org.apache.spark.sql.catalyst.util.DateTimeUtils.fromJavaTimestamp(ts)), + DataTypes.IntegerType); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestPartitionPruning.spark; + TestPartitionPruning.spark = null; + currentSpark.stop(); + } + + private static final Schema LOG_SCHEMA = + new Schema( + Types.NestedField.optional(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "date", Types.StringType.get()), + Types.NestedField.optional(3, "level", Types.StringType.get()), + Types.NestedField.optional(4, "message", Types.StringType.get()), + Types.NestedField.optional(5, "timestamp", Types.TimestampType.withZone())); + + private static final List LOGS = + ImmutableList.of( + LogMessage.debug("2020-02-02", "debug event 1", getInstant("2020-02-02T00:00:00")), + LogMessage.info("2020-02-02", "info event 1", getInstant("2020-02-02T01:00:00")), + LogMessage.debug("2020-02-02", "debug event 2", getInstant("2020-02-02T02:00:00")), + LogMessage.info("2020-02-03", "info event 2", getInstant("2020-02-03T00:00:00")), + LogMessage.debug("2020-02-03", "debug event 3", getInstant("2020-02-03T01:00:00")), + LogMessage.info("2020-02-03", "info event 3", getInstant("2020-02-03T02:00:00")), + LogMessage.error("2020-02-03", "error event 1", getInstant("2020-02-03T03:00:00")), + LogMessage.debug("2020-02-04", "debug event 4", getInstant("2020-02-04T01:00:00")), + LogMessage.warn("2020-02-04", "warn event 1", getInstant("2020-02-04T02:00:00")), + LogMessage.debug("2020-02-04", "debug event 5", getInstant("2020-02-04T03:00:00"))); + + private static Instant getInstant(String timestampWithoutZone) { + Long epochMicros = + (Long) Literal.of(timestampWithoutZone).to(Types.TimestampType.withoutZone()).value(); + return 
Instant.ofEpochMilli(TimeUnit.MICROSECONDS.toMillis(epochMicros)); + } + + @TempDir private java.nio.file.Path temp; + + private final PartitionSpec spec = + PartitionSpec.builderFor(LOG_SCHEMA) + .identity("date") + .identity("level") + .bucket("id", 3) + .truncate("message", 5) + .hour("timestamp") + .build(); + + @TestTemplate + public void testPartitionPruningIdentityString() { + String filterCond = "date >= '2020-02-03' AND level = 'DEBUG'"; + Predicate partCondition = + (Row r) -> { + String date = r.getString(0); + String level = r.getString(1); + return date.compareTo("2020-02-03") >= 0 && level.equals("DEBUG"); + }; + + runTest(filterCond, partCondition); + } + + @TestTemplate + public void testPartitionPruningBucketingInteger() { + final int[] ids = new int[] {LOGS.get(3).getId(), LOGS.get(7).getId()}; + String condForIds = + Arrays.stream(ids).mapToObj(String::valueOf).collect(Collectors.joining(",", "(", ")")); + String filterCond = "id in " + condForIds; + Predicate partCondition = + (Row r) -> { + int bucketId = r.getInt(2); + Set buckets = + Arrays.stream(ids).map(BUCKET_FUNC::apply).boxed().collect(Collectors.toSet()); + return buckets.contains(bucketId); + }; + + runTest(filterCond, partCondition); + } + + @TestTemplate + public void testPartitionPruningTruncatedString() { + String filterCond = "message like 'info event%'"; + Predicate partCondition = + (Row r) -> { + String truncatedMessage = r.getString(3); + return truncatedMessage.equals("info "); + }; + + runTest(filterCond, partCondition); + } + + @TestTemplate + public void testPartitionPruningTruncatedStringComparingValueShorterThanPartitionValue() { + String filterCond = "message like 'inf%'"; + Predicate partCondition = + (Row r) -> { + String truncatedMessage = r.getString(3); + return truncatedMessage.startsWith("inf"); + }; + + runTest(filterCond, partCondition); + } + + @TestTemplate + public void testPartitionPruningHourlyPartition() { + String filterCond; + if (spark.version().startsWith("2")) { + // Looks like from Spark 2 we need to compare timestamp with timestamp to push down the + // filter. 
+ filterCond = "timestamp >= to_timestamp('2020-02-03T01:00:00')"; + } else { + filterCond = "timestamp >= '2020-02-03T01:00:00'"; + } + Predicate partCondition = + (Row r) -> { + int hourValue = r.getInt(4); + Instant instant = getInstant("2020-02-03T01:00:00"); + Integer hourValueToFilter = + HOUR_FUNC.apply(TimeUnit.MILLISECONDS.toMicros(instant.toEpochMilli())); + return hourValue >= hourValueToFilter; + }; + + runTest(filterCond, partCondition); + } + + private void runTest(String filterCond, Predicate partCondition) { + File originTableLocation = createTempDir(); + assertThat(originTableLocation).as("Temp folder should exist").exists(); + + Table table = createTable(originTableLocation); + Dataset logs = createTestDataset(); + saveTestDatasetToTable(logs, table); + + List expected = + logs.select("id", "date", "level", "message", "timestamp") + .filter(filterCond) + .orderBy("id") + .collectAsList(); + assertThat(expected).as("Expected rows should not be empty").isNotEmpty(); + + // remove records which may be recorded during storing to table + CountOpenLocalFileSystem.resetRecordsInPathPrefix(originTableLocation.getAbsolutePath()); + + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(table.location()) + .select("id", "date", "level", "message", "timestamp") + .filter(filterCond) + .orderBy("id") + .collectAsList(); + assertThat(actual).as("Actual rows should not be empty").isNotEmpty(); + + assertThat(actual).as("Rows should match").isEqualTo(expected); + + assertAccessOnDataFiles(originTableLocation, table, partCondition); + } + + private File createTempDir() { + try { + return Files.createTempDirectory(temp, "junit").toFile(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private Table createTable(File originTableLocation) { + String trackedTableLocation = CountOpenLocalFileSystem.convertPath(originTableLocation); + Map properties = + ImmutableMap.of( + TableProperties.DEFAULT_FILE_FORMAT, format, + TableProperties.DATA_PLANNING_MODE, planningMode.modeName(), + TableProperties.DELETE_PLANNING_MODE, planningMode.modeName()); + return TABLES.create(LOG_SCHEMA, spec, properties, trackedTableLocation); + } + + private Dataset createTestDataset() { + List rows = + LOGS.stream() + .map( + logMessage -> { + Object[] underlying = + new Object[] { + logMessage.getId(), + UTF8String.fromString(logMessage.getDate()), + UTF8String.fromString(logMessage.getLevel()), + UTF8String.fromString(logMessage.getMessage()), + // discard the nanoseconds part to simplify + TimeUnit.MILLISECONDS.toMicros(logMessage.getTimestamp().toEpochMilli()) + }; + return new GenericInternalRow(underlying); + }) + .collect(Collectors.toList()); + + JavaRDD rdd = sparkContext.parallelize(rows); + Dataset df = + spark.internalCreateDataFrame( + JavaRDD.toRDD(rdd), SparkSchemaUtil.convert(LOG_SCHEMA), false); + + return df.selectExpr("id", "date", "level", "message", "timestamp") + .selectExpr( + "id", + "date", + "level", + "message", + "timestamp", + "bucket3(id) AS bucket_id", + "truncate5(message) AS truncated_message", + "hour(timestamp) AS ts_hour"); + } + + private void saveTestDatasetToTable(Dataset logs, Table table) { + logs.orderBy("date", "level", "bucket_id", "truncated_message", "ts_hour") + .select("id", "date", "level", "message", "timestamp") + .write() + .format("iceberg") + .mode("append") + .save(table.location()); + } + + private void assertAccessOnDataFiles( + File originTableLocation, 
Table table, Predicate partCondition) { + // only use files in current table location to avoid side-effects on concurrent test runs + Set readFilesInQuery = + CountOpenLocalFileSystem.pathToNumOpenCalled.keySet().stream() + .filter(path -> path.startsWith(originTableLocation.getAbsolutePath())) + .collect(Collectors.toSet()); + + List files = + spark.read().format("iceberg").load(table.location() + "#files").collectAsList(); + + Set filesToRead = extractFilePathsMatchingConditionOnPartition(files, partCondition); + Set filesToNotRead = extractFilePathsNotIn(files, filesToRead); + + // Just to be sure, they should be mutually exclusive. + assertThat(filesToRead).doesNotContainAnyElementsOf(filesToNotRead); + + assertThat(filesToNotRead).as("The query should prune some data files.").isNotEmpty(); + + // We don't check "all" data files bound to the condition are being read, as data files can be + // pruned on + // other conditions like lower/upper bound of columns. + assertThat(filesToRead) + .as( + "Some of data files in partition range should be read. " + + "Read files in query: " + + readFilesInQuery + + " / data files in partition range: " + + filesToRead) + .containsAnyElementsOf(readFilesInQuery); + + // Data files which aren't bound to the condition shouldn't be read. + assertThat(filesToNotRead) + .as( + "Data files outside of partition range should not be read. " + + "Read files in query: " + + readFilesInQuery + + " / data files outside of partition range: " + + filesToNotRead) + .doesNotContainAnyElementsOf(readFilesInQuery); + } + + private Set extractFilePathsMatchingConditionOnPartition( + List files, Predicate condition) { + // idx 1: file_path, idx 3: partition + return files.stream() + .filter( + r -> { + Row partition = r.getStruct(4); + return condition.test(partition); + }) + .map(r -> CountOpenLocalFileSystem.stripScheme(r.getString(1))) + .collect(Collectors.toSet()); + } + + private Set extractFilePathsNotIn(List files, Set filePaths) { + Set allFilePaths = + files.stream() + .map(r -> CountOpenLocalFileSystem.stripScheme(r.getString(1))) + .collect(Collectors.toSet()); + return Sets.newHashSet(Sets.symmetricDifference(allFilePaths, filePaths)); + } + + public static class CountOpenLocalFileSystem extends RawLocalFileSystem { + public static String scheme = + String.format("TestIdentityPartitionData%dfs", new Random().nextInt()); + public static Map pathToNumOpenCalled = Maps.newConcurrentMap(); + + public static String convertPath(String absPath) { + return scheme + "://" + absPath; + } + + public static String convertPath(File file) { + return convertPath(file.getAbsolutePath()); + } + + public static String stripScheme(String pathWithScheme) { + if (!pathWithScheme.startsWith(scheme + ":")) { + throw new IllegalArgumentException("Received unexpected path: " + pathWithScheme); + } + + int idxToCut = scheme.length() + 1; + while (pathWithScheme.charAt(idxToCut) == '/') { + idxToCut++; + } + + // leave the last '/' + idxToCut--; + + return pathWithScheme.substring(idxToCut); + } + + public static void resetRecordsInPathPrefix(String pathPrefix) { + pathToNumOpenCalled.keySet().stream() + .filter(p -> p.startsWith(pathPrefix)) + .forEach(key -> pathToNumOpenCalled.remove(key)); + } + + @Override + public URI getUri() { + return URI.create(scheme + ":///"); + } + + @Override + public String getScheme() { + return scheme; + } + + @Override + public FSDataInputStream open(Path f, int bufferSize) throws IOException { + String path = f.toUri().getPath(); + 
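+        // Count every open() per file path; assertAccessOnDataFiles reads this map to verify
+        // which data files a partition-pruned scan actually touched.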
pathToNumOpenCalled.compute( + path, + (ignored, v) -> { + if (v == null) { + return 1L; + } else { + return v + 1; + } + }); + return super.open(f, bufferSize); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java new file mode 100644 index 000000000000..5c218f21c47e --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java @@ -0,0 +1,509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.spark.data.TestVectorizedOrcDataReader.temp; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.File; +import java.nio.file.Path; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.Files; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.Types; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestPartitionValues { + @Parameters(name = "format = {0}, 
vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + {"parquet", false}, + {"parquet", true}, + {"avro", false}, + {"orc", false}, + {"orc", true} + }; + } + + private static final Schema SUPPORTED_PRIMITIVES = + new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get()), + required(102, "b", Types.BooleanType.get()), + required(103, "i", Types.IntegerType.get()), + required(104, "l", Types.LongType.get()), + required(105, "f", Types.FloatType.get()), + required(106, "d", Types.DoubleType.get()), + required(107, "date", Types.DateType.get()), + required(108, "ts", Types.TimestampType.withZone()), + required(110, "s", Types.StringType.get()), + required(113, "bytes", Types.BinaryType.get()), + required(114, "dec_9_0", Types.DecimalType.of(9, 0)), + required(115, "dec_11_2", Types.DecimalType.of(11, 2)), + required(116, "dec_38_10", Types.DecimalType.of(38, 10)) // spark's maximum precision + ); + + private static final Schema SIMPLE_SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SIMPLE_SCHEMA).identity("data").build(); + + private static SparkSession spark = null; + + @BeforeAll + public static void startSpark() { + TestPartitionValues.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestPartitionValues.spark; + TestPartitionValues.spark = null; + currentSpark.stop(); + } + + @TempDir private Path temp; + + @Parameter(index = 0) + private String format; + + @Parameter(index = 1) + private boolean vectorized; + + @TestTemplate + public void testNullPartitionValue() throws Exception { + String desc = "null_part"; + File parent = new File(temp.toFile(), desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + assertThat(dataFolder.mkdirs()).as("mkdirs should succeed").isTrue(); + + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + Table table = tables.create(SIMPLE_SCHEMA, SPEC, location.toString()); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, null)); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(location.toString()); + + Dataset result = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(location.toString()); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + } + + @TestTemplate + public void testReorderedColumns() throws Exception { + String desc = "reorder_columns"; + File parent = new File(temp.toFile(), desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + assertThat(dataFolder.mkdirs()).as("mkdirs should succeed").isTrue(); + + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + Table table = tables.create(SIMPLE_SCHEMA, SPEC, location.toString()); + 
table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("data", "id") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .option(SparkWriteOptions.CHECK_ORDERING, "false") + .save(location.toString()); + + Dataset result = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(location.toString()); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + } + + @TestTemplate + public void testReorderedColumnsNoNullability() throws Exception { + String desc = "reorder_columns_no_nullability"; + File parent = new File(temp.toFile(), desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + assertThat(dataFolder.mkdirs()).as("mkdirs should succeed").isTrue(); + + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + Table table = tables.create(SIMPLE_SCHEMA, SPEC, location.toString()); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("data", "id") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .option(SparkWriteOptions.CHECK_ORDERING, "false") + .option(SparkWriteOptions.CHECK_NULLABILITY, "false") + .save(location.toString()); + + Dataset result = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(location.toString()); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + } + + @TestTemplate + public void testPartitionValueTypes() throws Exception { + String[] columnNames = + new String[] { + "b", "i", "l", "f", "d", "date", "ts", "s", "bytes", "dec_9_0", "dec_11_2", "dec_38_10" + }; + + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + + // create a table around the source data + String sourceLocation = temp.resolve("source_table").toString(); + Table source = tables.create(SUPPORTED_PRIMITIVES, sourceLocation); + + // write out an Avro data file with all of the data types for source data + List expected = RandomData.generateList(source.schema(), 2, 128735L); + File avroData = File.createTempFile("data", ".avro", temp.toFile()); + assertThat(avroData.delete()).isTrue(); + try (FileAppender appender = + Avro.write(Files.localOutput(avroData)).schema(source.schema()).build()) { + appender.addAll(expected); + } + + // add the Avro data file to the source table + source + .newAppend() + .appendFile( + DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(10) + .withInputFile(Files.localInput(avroData)) + .build()) + .commit(); + + Dataset sourceDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + 
.load(sourceLocation); + + for (String column : columnNames) { + String desc = "partition_by_" + SUPPORTED_PRIMITIVES.findType(column).toString(); + + File parent = new File(temp.toFile(), desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + assertThat(dataFolder.mkdirs()).as("mkdirs should succeed").isTrue(); + + PartitionSpec spec = PartitionSpec.builderFor(SUPPORTED_PRIMITIVES).identity(column).build(); + + Table table = tables.create(SUPPORTED_PRIMITIVES, spec, location.toString()); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + // disable distribution/ordering and fanout writers to preserve the original ordering + sourceDF + .write() + .format("iceberg") + .mode(SaveMode.Append) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .option(SparkWriteOptions.FANOUT_ENABLED, "false") + .save(location.toString()); + + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(location.toString()) + .collectAsList(); + + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe( + SUPPORTED_PRIMITIVES.asStruct(), expected.get(i), actual.get(i)); + } + } + } + + @TestTemplate + public void testNestedPartitionValues() throws Exception { + String[] columnNames = + new String[] { + "b", "i", "l", "f", "d", "date", "ts", "s", "bytes", "dec_9_0", "dec_11_2", "dec_38_10" + }; + + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + Schema nestedSchema = new Schema(optional(1, "nested", SUPPORTED_PRIMITIVES.asStruct())); + + // create a table around the source data + String sourceLocation = temp.resolve("source_table").toString(); + Table source = tables.create(nestedSchema, sourceLocation); + + // write out an Avro data file with all of the data types for source data + List expected = RandomData.generateList(source.schema(), 2, 128735L); + File avroData = File.createTempFile("data", ".avro", temp.toFile()); + assertThat(avroData.delete()).isTrue(); + try (FileAppender appender = + Avro.write(Files.localOutput(avroData)).schema(source.schema()).build()) { + appender.addAll(expected); + } + + // add the Avro data file to the source table + source + .newAppend() + .appendFile( + DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(10) + .withInputFile(Files.localInput(avroData)) + .build()) + .commit(); + + Dataset sourceDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(sourceLocation); + + for (String column : columnNames) { + String desc = "partition_by_" + SUPPORTED_PRIMITIVES.findType(column).toString(); + + File parent = new File(temp.toFile(), desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + assertThat(dataFolder.mkdirs()).as("mkdirs should succeed").isTrue(); + + PartitionSpec spec = + PartitionSpec.builderFor(nestedSchema).identity("nested." 
+ column).build(); + + Table table = tables.create(nestedSchema, spec, location.toString()); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + // disable distribution/ordering and fanout writers to preserve the original ordering + sourceDF + .write() + .format("iceberg") + .mode(SaveMode.Append) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .option(SparkWriteOptions.FANOUT_ENABLED, "false") + .save(location.toString()); + + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(location.toString()) + .collectAsList(); + + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(nestedSchema.asStruct(), expected.get(i), actual.get(i)); + } + } + } + + /** + * To verify if WrappedPositionAccessor is generated against a string field within a nested field, + * rather than a Position2Accessor. Or when building the partition path, a ClassCastException is + * thrown with the message like: Cannot cast org.apache.spark.unsafe.types.UTF8String to + * java.lang.CharSequence + */ + @TestTemplate + public void testPartitionedByNestedString() throws Exception { + // schema and partition spec + Schema nestedSchema = + new Schema( + Types.NestedField.required( + 1, + "struct", + Types.StructType.of( + Types.NestedField.required(2, "string", Types.StringType.get())))); + PartitionSpec spec = PartitionSpec.builderFor(nestedSchema).identity("struct.string").build(); + + // create table + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + String baseLocation = temp.resolve("partition_by_nested_string").toString(); + tables.create(nestedSchema, spec, baseLocation); + + // input data frame + StructField[] structFields = { + new StructField( + "struct", + DataTypes.createStructType( + new StructField[] { + new StructField("string", DataTypes.StringType, false, Metadata.empty()) + }), + false, + Metadata.empty()) + }; + + List rows = Lists.newArrayList(); + rows.add(RowFactory.create(RowFactory.create("nested_string_value"))); + Dataset sourceDF = spark.createDataFrame(rows, new StructType(structFields)); + + // write into iceberg + sourceDF.write().format("iceberg").mode(SaveMode.Append).save(baseLocation); + + // verify + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(baseLocation) + .collectAsList(); + + assertThat(actual).as("Number of rows should match").hasSameSizeAs(rows); + } + + @TestTemplate + public void testReadPartitionColumn() throws Exception { + assumeThat(format).as("Temporary skip ORC").isNotEqualTo("orc"); + + Schema nestedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional( + 2, + "struct", + Types.StructType.of( + Types.NestedField.optional(3, "innerId", Types.LongType.get()), + Types.NestedField.optional(4, "innerName", Types.StringType.get())))); + PartitionSpec spec = + PartitionSpec.builderFor(nestedSchema).identity("struct.innerName").build(); + + // create table + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + String baseLocation = temp.resolve("partition_by_nested_string").toString(); + Table table = tables.create(nestedSchema, spec, baseLocation); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, 
format).commit(); + + // write into iceberg + MapFunction func = + value -> new ComplexRecord(value, new NestedRecord(value, "name_" + value)); + spark + .range(0, 10, 1, 1) + .map(func, Encoders.bean(ComplexRecord.class)) + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(baseLocation); + + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(baseLocation) + .select("struct.innerName") + .orderBy("struct.innerName") + .as(Encoders.STRING()) + .collectAsList(); + + assertThat(actual).as("Number of rows should match").hasSize(10); + + List inputRecords = + IntStream.range(0, 10).mapToObj(i -> "name_" + i).collect(Collectors.toList()); + assertThat(actual).as("Read object should be matched").isEqualTo(inputRecords); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestPathIdentifier.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestPathIdentifier.java new file mode 100644 index 000000000000..bb026b2ab2da --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestPathIdentifier.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.Schema; +import org.apache.iceberg.hadoop.HadoopTableOperations; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.PathIdentifier; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestPathIdentifier extends TestBase { + + private static final Schema SCHEMA = + new Schema( + required(1, "id", Types.LongType.get()), required(2, "data", Types.StringType.get())); + + @TempDir private Path temp; + private File tableLocation; + private PathIdentifier identifier; + private SparkCatalog sparkCatalog; + + @BeforeEach + public void before() throws IOException { + tableLocation = temp.toFile(); + identifier = new PathIdentifier(tableLocation.getAbsolutePath()); + sparkCatalog = new SparkCatalog(); + sparkCatalog.initialize("test", new CaseInsensitiveStringMap(ImmutableMap.of())); + } + + @AfterEach + public void after() { + tableLocation.delete(); + sparkCatalog = null; + } + + @Test + public void testPathIdentifier() throws TableAlreadyExistsException, NoSuchTableException { + SparkTable table = + (SparkTable) + sparkCatalog.createTable( + identifier, SparkSchemaUtil.convert(SCHEMA), new Transform[0], ImmutableMap.of()); + + assertThat(tableLocation.getAbsolutePath()).isEqualTo(table.table().location()); + assertThat(table.table()).isInstanceOf(BaseTable.class); + assertThat(((BaseTable) table.table()).operations()).isInstanceOf(HadoopTableOperations.class); + + assertThat(table).isEqualTo(sparkCatalog.loadTable(identifier)); + assertThat(sparkCatalog.dropTable(identifier)).isTrue(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestPositionDeletesTable.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestPositionDeletesTable.java new file mode 100644 index 000000000000..a991d6ccbff7 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestPositionDeletesTable.java @@ -0,0 +1,1621 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.File; +import java.io.IOException; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiFunction; +import java.util.stream.Collectors; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.MetadataTableUtils; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.PositionDeletesScanTask; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.PositionDeletesRewriteCoordinator; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkStructLike; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.CharSequenceSet; +import org.apache.iceberg.util.DeleteFileSet; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.functions; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestPositionDeletesTable extends CatalogTestBase { + + public static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", 
Types.IntegerType.get()), + Types.NestedField.required(2, "data", Types.StringType.get())); + private static final Map CATALOG_PROPS = + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "cache-enabled", "false"); + private static final List NON_PATH_COLS = + ImmutableList.of("file_path", "pos", "row", "partition", "spec_id"); + + @Parameter(index = 3) + private FileFormat format; + + @Parameters(name = "catalogName = {1}, implementation = {2}, config = {3}, fileFormat = {4}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + CATALOG_PROPS, + FileFormat.PARQUET + }, + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + CATALOG_PROPS, + FileFormat.AVRO + }, + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + CATALOG_PROPS, + FileFormat.ORC + }, + }; + } + + @TestTemplate + public void testNullRows() throws IOException { + String tableName = "null_rows"; + Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned()); + + DataFile dFile = dataFile(tab); + tab.newAppend().appendFile(dFile).commit(); + + List> deletes = Lists.newArrayList(); + deletes.add(Pair.of(dFile.location(), 0L)); + deletes.add(Pair.of(dFile.location(), 1L)); + Pair posDeletes = + FileHelpers.writeDeleteFile( + tab, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + deletes); + tab.newRowDelta().addDeletes(posDeletes.first()).commit(); + + StructLikeSet actual = actual(tableName, tab); + + List> expectedDeletes = + Lists.newArrayList( + positionDelete(dFile.location(), 0L), positionDelete(dFile.location(), 1L)); + StructLikeSet expected = expected(tab, expectedDeletes, null, posDeletes.first().location()); + + assertThat(actual).as("Position Delete table should contain expected rows").isEqualTo(expected); + dropTable(tableName); + } + + @TestTemplate + public void testPartitionedTable() throws IOException { + // Create table with two partitions + String tableName = "partitioned_table"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + // Add position deletes for both partitions + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Select deletes from one partition + StructLikeSet actual = actual(tableName, tab, "row.data='b'"); + GenericRecord partitionB = GenericRecord.create(tab.spec().partitionType()); + partitionB.setField("data", "b"); + StructLikeSet expected = + expected(tab, deletesB.first(), partitionB, deletesB.second().location()); + + assertThat(actual).as("Position Delete table should contain expected rows").isEqualTo(expected); + dropTable(tableName); + } + + @TestTemplate + public void testSelect() throws IOException { + // Create table with two partitions + String tableName = "select"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + + 
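+    // append both data files in a single commit; position deletes are added against them below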
tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + // Add position deletes for both partitions + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Select certain columns + Dataset df = + spark + .read() + .format("iceberg") + .load("default." + tableName + ".position_deletes") + .withColumn("input_file", functions.input_file_name()) + .select("row.id", "pos", "delete_file_path", "input_file"); + List actual = rowsToJava(df.collectAsList()); + + // Select cols from expected delete values + List expected = Lists.newArrayList(); + BiFunction, DeleteFile, Object[]> toRow = + (delete, file) -> { + int rowData = delete.get(2, GenericRecord.class).get(0, Integer.class); + long pos = delete.get(1, Long.class); + return row(rowData, pos, file.location(), file.location()); + }; + expected.addAll( + deletesA.first().stream() + .map(d -> toRow.apply(d, deletesA.second())) + .collect(Collectors.toList())); + expected.addAll( + deletesB.first().stream() + .map(d -> toRow.apply(d, deletesB.second())) + .collect(Collectors.toList())); + + // Sort and compare + Comparator comp = + (o1, o2) -> { + int result = Integer.compare((int) o1[0], (int) o2[0]); + if (result != 0) { + return result; + } else { + return ((String) o1[2]).compareTo((String) o2[2]); + } + }; + actual.sort(comp); + expected.sort(comp); + assertThat(actual) + .as("Position Delete table should contain expected rows") + .usingRecursiveComparison() + .isEqualTo(expected); + dropTable(tableName); + } + + @TestTemplate + public void testSplitTasks() throws IOException { + String tableName = "big_table"; + Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned()); + tab.updateProperties().set("read.split.target-size", "100").commit(); + int records = 500; + + GenericRecord record = GenericRecord.create(tab.schema()); + List dataRecords = Lists.newArrayList(); + for (int i = 0; i < records; i++) { + dataRecords.add(record.copy("id", i, "data", String.valueOf(i))); + } + DataFile dFile = + FileHelpers.writeDataFile( + tab, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + org.apache.iceberg.TestHelpers.Row.of(), + dataRecords); + tab.newAppend().appendFile(dFile).commit(); + + List> deletes = Lists.newArrayList(); + for (long i = 0; i < records; i++) { + deletes.add(positionDelete(tab.schema(), dFile.location(), i, (int) i, String.valueOf(i))); + } + DeleteFile posDeletes = + FileHelpers.writePosDeleteFile( + tab, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + deletes); + tab.newRowDelta().addDeletes(posDeletes).commit(); + + Table deleteTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + + if (format.equals(FileFormat.AVRO)) { + assertThat(deleteTable.newBatchScan().planTasks()) + .as("Position delete scan should produce more than one split") + .hasSizeGreaterThan(1); + } else { + assertThat(deleteTable.newBatchScan().planTasks()) + .as("Position delete scan should produce one split") + .hasSize(1); + } + + StructLikeSet actual = actual(tableName, tab); + StructLikeSet expected = expected(tab, deletes, null, posDeletes.location()); + + assertThat(actual).as("Position Delete table should contain expected rows").isEqualTo(expected); + dropTable(tableName); + } + + @TestTemplate + public void testPartitionFilter() 
throws IOException { + // Create table with two partitions + String tableName = "partition_filter"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + Table deletesTab = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + // Add position deletes for both partitions + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileA, "b"); + + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Prepare expected values + GenericRecord partitionRecordTemplate = GenericRecord.create(tab.spec().partitionType()); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + Record partitionB = partitionRecordTemplate.copy("data", "b"); + StructLikeSet expectedA = + expected(tab, deletesA.first(), partitionA, deletesA.second().location()); + StructLikeSet expectedB = + expected(tab, deletesB.first(), partitionB, deletesB.second().location()); + StructLikeSet allExpected = StructLikeSet.create(deletesTab.schema().asStruct()); + allExpected.addAll(expectedA); + allExpected.addAll(expectedB); + + // Select deletes from all partitions + StructLikeSet actual = actual(tableName, tab); + assertThat(actual) + .as("Position Delete table should contain expected rows") + .isEqualTo(allExpected); + + // Select deletes from one partition + StructLikeSet actual2 = actual(tableName, tab, "partition.data = 'a' AND pos >= 0"); + + assertThat(actual2) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedA); + dropTable(tableName); + } + + @TestTemplate + public void testPartitionTransformFilter() throws IOException { + // Create table with two partitions + String tableName = "partition_filter"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).truncate("data", 1).build(); + Table tab = createTable(tableName, SCHEMA, spec); + Table deletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + + DataFile dataFileA = dataFile(tab, new Object[] {"aa"}, new Object[] {"a"}); + DataFile dataFileB = dataFile(tab, new Object[] {"bb"}, new Object[] {"b"}); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + // Add position deletes for both partitions + Pair>, DeleteFile> deletesA = + deleteFile(tab, dataFileA, new Object[] {"aa"}, new Object[] {"a"}); + Pair>, DeleteFile> deletesB = + deleteFile(tab, dataFileA, new Object[] {"bb"}, new Object[] {"b"}); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Prepare expected values + GenericRecord partitionRecordTemplate = GenericRecord.create(tab.spec().partitionType()); + Record partitionA = partitionRecordTemplate.copy("data_trunc", "a"); + Record partitionB = partitionRecordTemplate.copy("data_trunc", "b"); + StructLikeSet expectedA = + expected(tab, deletesA.first(), partitionA, deletesA.second().location()); + StructLikeSet expectedB = + expected(tab, deletesB.first(), partitionB, deletesB.second().location()); + StructLikeSet allExpected = StructLikeSet.create(deletesTable.schema().asStruct()); + allExpected.addAll(expectedA); + allExpected.addAll(expectedB); + + // Select deletes from all partitions + StructLikeSet actual = actual(tableName, 
tab); + assertThat(actual) + .as("Position Delete table should contain expected rows") + .isEqualTo(allExpected); + + // Select deletes from one partition + StructLikeSet actual2 = actual(tableName, tab, "partition.data_trunc = 'a' AND pos >= 0"); + + assertThat(actual2) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedA); + dropTable(tableName); + } + + @TestTemplate + public void testPartitionEvolutionReplace() throws Exception { + // Create table with spec (data) + String tableName = "partition_evolution"; + PartitionSpec originalSpec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, originalSpec); + int dataSpec = tab.spec().specId(); + + // Add files with old spec (data) + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileA, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Switch partition spec from (data) to (id) + tab.updateSpec().removeField("data").addField("id").commit(); + + // Add data and delete files with new spec (id) + DataFile dataFile10 = dataFile(tab, 10); + DataFile dataFile99 = dataFile(tab, 99); + tab.newAppend().appendFile(dataFile10).appendFile(dataFile99).commit(); + + Pair>, DeleteFile> deletes10 = deleteFile(tab, dataFile10, 10); + Pair>, DeleteFile> deletes99 = deleteFile(tab, dataFile10, 99); + tab.newRowDelta().addDeletes(deletes10.second()).addDeletes(deletes99.second()).commit(); + + // Query partition of old spec + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + StructLikeSet expectedA = + expected(tab, deletesA.first(), partitionA, dataSpec, deletesA.second().location()); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a' AND pos >= 0"); + assertThat(actualA) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedA); + + // Query partition of new spec + Record partition10 = partitionRecordTemplate.copy("id", 10); + StructLikeSet expected10 = + expected( + tab, + deletes10.first(), + partition10, + tab.spec().specId(), + deletes10.second().location()); + StructLikeSet actual10 = actual(tableName, tab, "partition.id = 10 AND pos >= 0"); + + assertThat(actual10) + .as("Position Delete table should contain expected rows") + .isEqualTo(expected10); + dropTable(tableName); + } + + @TestTemplate + public void testPartitionEvolutionAdd() throws Exception { + // Create unpartitioned table + String tableName = "partition_evolution_add"; + Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned()); + int specId0 = tab.spec().specId(); + + // Add files with unpartitioned spec + DataFile dataFileUnpartitioned = dataFile(tab); + tab.newAppend().appendFile(dataFileUnpartitioned).commit(); + Pair>, DeleteFile> deletesUnpartitioned = + deleteFile(tab, dataFileUnpartitioned); + tab.newRowDelta().addDeletes(deletesUnpartitioned.second()).commit(); + + // Switch partition spec to (data) + tab.updateSpec().addField("data").commit(); + int specId1 = tab.spec().specId(); + + // Add files with new spec (data) + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + 
tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Select deletes from new spec (data) + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + StructLikeSet expectedA = + expected(tab, deletesA.first(), partitionA, specId1, deletesA.second().location()); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a' AND pos >= 0"); + assertThat(actualA) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedA); + + // Select deletes from 'unpartitioned' data + Record unpartitionedRecord = partitionRecordTemplate.copy("data", null); + StructLikeSet expectedUnpartitioned = + expected( + tab, + deletesUnpartitioned.first(), + unpartitionedRecord, + specId0, + deletesUnpartitioned.second().location()); + StructLikeSet actualUnpartitioned = + actual(tableName, tab, "partition.data IS NULL and pos >= 0"); + + assertThat(actualUnpartitioned) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedUnpartitioned); + dropTable(tableName); + } + + @TestTemplate + public void testPartitionEvolutionRemove() throws Exception { + // Create table with spec (data) + String tableName = "partition_evolution_remove"; + PartitionSpec originalSpec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, originalSpec); + int specId0 = tab.spec().specId(); + + // Add files with spec (data) + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Remove partition field + tab.updateSpec().removeField("data").commit(); + int specId1 = tab.spec().specId(); + + // Add unpartitioned files + DataFile dataFileUnpartitioned = dataFile(tab); + tab.newAppend().appendFile(dataFileUnpartitioned).commit(); + Pair>, DeleteFile> deletesUnpartitioned = + deleteFile(tab, dataFileUnpartitioned); + tab.newRowDelta().addDeletes(deletesUnpartitioned.second()).commit(); + + // Select deletes from (data) spec + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + StructLikeSet expectedA = + expected(tab, deletesA.first(), partitionA, specId0, deletesA.second().location()); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a' AND pos >= 0"); + assertThat(actualA) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedA); + + // Select deletes from 'unpartitioned' spec + Record unpartitionedRecord = partitionRecordTemplate.copy("data", null); + StructLikeSet expectedUnpartitioned = + expected( + tab, + deletesUnpartitioned.first(), + unpartitionedRecord, + specId1, + deletesUnpartitioned.second().location()); + StructLikeSet actualUnpartitioned = + actual(tableName, tab, "partition.data IS NULL and pos >= 0"); + + assertThat(actualUnpartitioned) + .as("Position Delete table should contain expected rows") + 
.isEqualTo(expectedUnpartitioned); + dropTable(tableName); + } + + @TestTemplate + public void testSpecIdFilter() throws Exception { + // Create table with spec (data) + String tableName = "spec_id_filter"; + Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned()); + int unpartitionedSpec = tab.spec().specId(); + + // Add data file and delete + DataFile dataFileUnpartitioned = dataFile(tab); + tab.newAppend().appendFile(dataFileUnpartitioned).commit(); + Pair>, DeleteFile> deletesUnpartitioned = + deleteFile(tab, dataFileUnpartitioned); + tab.newRowDelta().addDeletes(deletesUnpartitioned.second()).commit(); + + // Switch partition spec to (data) and add files + tab.updateSpec().addField("data").commit(); + int dataSpec = tab.spec().specId(); + + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Select deletes from 'unpartitioned' + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + StructLikeSet expectedUnpartitioned = + expected( + tab, + deletesUnpartitioned.first(), + partitionRecordTemplate, + unpartitionedSpec, + deletesUnpartitioned.second().location()); + StructLikeSet actualUnpartitioned = + actual(tableName, tab, String.format("spec_id = %d", unpartitionedSpec)); + assertThat(actualUnpartitioned) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedUnpartitioned); + + // Select deletes from 'data' partition spec + StructLike partitionA = partitionRecordTemplate.copy("data", "a"); + StructLike partitionB = partitionRecordTemplate.copy("data", "b"); + StructLikeSet expected = + expected(tab, deletesA.first(), partitionA, dataSpec, deletesA.second().location()); + expected.addAll( + expected(tab, deletesB.first(), partitionB, dataSpec, deletesB.second().location())); + + StructLikeSet actual = actual(tableName, tab, String.format("spec_id = %d", dataSpec)); + assertThat(actual).as("Position Delete table should contain expected rows").isEqualTo(expected); + dropTable(tableName); + } + + @TestTemplate + public void testSchemaEvolutionAdd() throws Exception { + // Create table with original schema + String tableName = "schema_evolution_add"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + + // Add files with original schema + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Add files with new schema + tab.updateSchema() + .addColumn("new_col_1", Types.IntegerType.get()) + .addColumn("new_col_2", Types.IntegerType.get()) + .commit(); + + // Add files with new schema + DataFile dataFileC = dataFile(tab, "c"); + DataFile dataFileD = dataFile(tab, "d"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesC = deleteFile(tab, dataFileC, "c"); + Pair>, DeleteFile> deletesD = deleteFile(tab, dataFileD, "d"); + 
tab.newRowDelta().addDeletes(deletesC.second()).addDeletes(deletesD.second()).commit(); + + // Select deletes from old schema + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + // pad expected delete rows with null values for new columns + List> expectedDeletesA = deletesA.first(); + expectedDeletesA.forEach( + d -> { + GenericRecord nested = d.get(2, GenericRecord.class); + GenericRecord padded = GenericRecord.create(tab.schema().asStruct()); + padded.set(0, nested.get(0)); + padded.set(1, nested.get(1)); + padded.set(2, null); + padded.set(3, null); + d.set(2, padded); + }); + StructLikeSet expectedA = + expected(tab, expectedDeletesA, partitionA, deletesA.second().location()); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a' AND pos >= 0"); + assertThat(actualA) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedA); + + // Select deletes from new schema + Record partitionC = partitionRecordTemplate.copy("data", "c"); + StructLikeSet expectedC = + expected(tab, deletesC.first(), partitionC, deletesC.second().location()); + StructLikeSet actualC = actual(tableName, tab, "partition.data = 'c' and pos >= 0"); + + assertThat(actualC) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedC); + dropTable(tableName); + } + + @TestTemplate + public void testSchemaEvolutionRemove() throws Exception { + // Create table with original schema + String tableName = "schema_evolution_remove"; + Schema oldSchema = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.required(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "new_col_1", Types.IntegerType.get()), + Types.NestedField.optional(4, "new_col_2", Types.IntegerType.get())); + PartitionSpec spec = PartitionSpec.builderFor(oldSchema).identity("data").build(); + Table tab = createTable(tableName, oldSchema, spec); + + // Add files with original schema + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Add files with new schema + tab.updateSchema().deleteColumn("new_col_1").deleteColumn("new_col_2").commit(); + + // Add files with new schema + DataFile dataFileC = dataFile(tab, "c"); + DataFile dataFileD = dataFile(tab, "d"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesC = deleteFile(tab, dataFileC, "c"); + Pair>, DeleteFile> deletesD = deleteFile(tab, dataFileD, "d"); + tab.newRowDelta().addDeletes(deletesC.second()).addDeletes(deletesD.second()).commit(); + + // Select deletes from old schema + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + // remove deleted columns from expected result + List> expectedDeletesA = deletesA.first(); + expectedDeletesA.forEach( + d -> { + GenericRecord nested = d.get(2, GenericRecord.class); + GenericRecord padded = GenericRecord.create(tab.schema().asStruct()); + padded.set(0, nested.get(0)); + padded.set(1, nested.get(1)); + d.set(2, padded); + }); + StructLikeSet 
expectedA = + expected(tab, expectedDeletesA, partitionA, deletesA.second().location()); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a' AND pos >= 0"); + assertThat(actualA) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedA); + + // Select deletes from new schema + Record partitionC = partitionRecordTemplate.copy("data", "c"); + StructLikeSet expectedC = + expected(tab, deletesC.first(), partitionC, deletesC.second().location()); + StructLikeSet actualC = actual(tableName, tab, "partition.data = 'c' and pos >= 0"); + + assertThat(actualC) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedC); + dropTable(tableName); + } + + @TestTemplate + public void testWrite() throws IOException, NoSuchTableException { + String tableName = "test_write"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + // Add position deletes for both partitions + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + Table posDeletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." + tableName + ".position_deletes"; + for (String partValue : ImmutableList.of("a", "b")) { + try (CloseableIterable tasks = tasks(posDeletesTable, "data", partValue)) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + + assertThat(scanDF.javaRDD().getNumPartitions()).isEqualTo(1); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + } + + // Prepare expected values (without 'delete_file_path' as these have been rewritten) + GenericRecord partitionRecordTemplate = GenericRecord.create(tab.spec().partitionType()); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + Record partitionB = partitionRecordTemplate.copy("data", "b"); + StructLikeSet expectedA = expected(tab, deletesA.first(), partitionA, null); + StructLikeSet expectedB = expected(tab, deletesB.first(), partitionB, null); + StructLikeSet allExpected = + StructLikeSet.create( + TypeUtil.selectNot( + posDeletesTable.schema(), ImmutableSet.of(MetadataColumns.FILE_PATH_COLUMN_ID)) + .asStruct()); + allExpected.addAll(expectedA); + allExpected.addAll(expectedB); + + // Compare values without 'delete_file_path' as these have been rewritten + StructLikeSet actual = actual(tableName, tab, null, NON_PATH_COLS); + assertThat(actual) + .as("Position Delete table should contain expected rows") + .isEqualTo(allExpected); + dropTable(tableName); + } + + @TestTemplate + public void testWriteUnpartitionedNullRows() throws Exception { + String tableName = "write_null_rows"; + Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned()); + + DataFile dFile = dataFile(tab); + 
tab.newAppend().appendFile(dFile).commit(); + + List> deletes = Lists.newArrayList(); + deletes.add(Pair.of(dFile.location(), 0L)); + deletes.add(Pair.of(dFile.location(), 1L)); + Pair posDeletes = + FileHelpers.writeDeleteFile( + tab, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + deletes); + tab.newRowDelta().addDeletes(posDeletes.first()).commit(); + + Table posDeletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." + tableName + ".position_deletes"; + try (CloseableIterable tasks = posDeletesTable.newBatchScan().planFiles()) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + assertThat(scanDF.javaRDD().getNumPartitions()).isEqualTo(1); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + + // Compare values without 'delete_file_path' as these have been rewritten + StructLikeSet actual = + actual(tableName, tab, null, ImmutableList.of("file_path", "pos", "row", "spec_id")); + + List> expectedDeletes = + Lists.newArrayList( + positionDelete(dFile.location(), 0L), positionDelete(dFile.location(), 1L)); + StructLikeSet expected = expected(tab, expectedDeletes, null, null); + + assertThat(actual).as("Position Delete table should contain expected rows").isEqualTo(expected); + dropTable(tableName); + } + + @TestTemplate + public void testWriteMixedRows() throws Exception { + String tableName = "write_mixed_rows"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + // Add a delete file with row and without row + List> deletes = Lists.newArrayList(); + deletes.add(Pair.of(dataFileA.location(), 0L)); + deletes.add(Pair.of(dataFileA.location(), 1L)); + Pair deletesWithoutRow = + FileHelpers.writeDeleteFile( + tab, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of("a"), + deletes); + + Pair>, DeleteFile> deletesWithRow = deleteFile(tab, dataFileB, "b"); + + tab.newRowDelta() + .addDeletes(deletesWithoutRow.first()) + .addDeletes(deletesWithRow.second()) + .commit(); + + // rewrite delete files + Table posDeletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." 
+ tableName + ".position_deletes"; + for (String partValue : ImmutableList.of("a", "b")) { + try (CloseableIterable tasks = tasks(posDeletesTable, "data", partValue)) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .load(posDeletesTableName); + assertThat(scanDF.javaRDD().getNumPartitions()).isEqualTo(1); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + } + + // Compare values without 'delete_file_path' as these have been rewritten + StructLikeSet actual = + actual( + tableName, + tab, + null, + ImmutableList.of("file_path", "pos", "row", "partition", "spec_id")); + + // Prepare expected values + GenericRecord partitionRecordTemplate = GenericRecord.create(tab.spec().partitionType()); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + Record partitionB = partitionRecordTemplate.copy("data", "b"); + StructLikeSet allExpected = + StructLikeSet.create( + TypeUtil.selectNot( + posDeletesTable.schema(), ImmutableSet.of(MetadataColumns.FILE_PATH_COLUMN_ID)) + .asStruct()); + allExpected.addAll( + expected( + tab, + Lists.newArrayList( + positionDelete(dataFileA.location(), 0L), positionDelete(dataFileA.location(), 1L)), + partitionA, + null)); + allExpected.addAll(expected(tab, deletesWithRow.first(), partitionB, null)); + + assertThat(actual) + .as("Position Delete table should contain expected rows") + .isEqualTo(allExpected); + dropTable(tableName); + } + + @TestTemplate + public void testWritePartitionEvolutionAdd() throws Exception { + // Create unpartitioned table + String tableName = "write_partition_evolution_add"; + Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned()); + int specId0 = tab.spec().specId(); + + // Add files with unpartitioned spec + DataFile dataFileUnpartitioned = dataFile(tab); + tab.newAppend().appendFile(dataFileUnpartitioned).commit(); + Pair>, DeleteFile> deletesUnpartitioned = + deleteFile(tab, dataFileUnpartitioned); + tab.newRowDelta().addDeletes(deletesUnpartitioned.second()).commit(); + + // Switch partition spec to (data) + tab.updateSpec().addField("data").commit(); + int specId1 = tab.spec().specId(); + + // Add files with new spec (data) + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + Table posDeletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." 
+ tableName + ".position_deletes"; + + // Read/write back unpartitioned data + try (CloseableIterable tasks = + posDeletesTable.newBatchScan().filter(Expressions.isNull("partition.data")).planFiles()) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + assertThat(scanDF.javaRDD().getNumPartitions()).isEqualTo(1); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + + // Select deletes from unpartitioned data + // Compare values without 'delete_file_path' as these have been rewritten + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record unpartitionedRecord = partitionRecordTemplate.copy("data", null); + StructLikeSet expectedUnpartitioned = + expected(tab, deletesUnpartitioned.first(), unpartitionedRecord, specId0, null); + StructLikeSet actualUnpartitioned = + actual(tableName, tab, "partition.data IS NULL", NON_PATH_COLS); + assertThat(actualUnpartitioned) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedUnpartitioned); + + // Read/write back new partition spec (data) + for (String partValue : ImmutableList.of("a", "b")) { + try (CloseableIterable tasks = tasks(posDeletesTable, "data", partValue)) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + assertThat(scanDF.javaRDD().getNumPartitions()).isEqualTo(1); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + // commit the rewrite + commit(tab, posDeletesTable, fileSetID, 1); + } + } + + // Select deletes from new spec (data) + Record partitionA = partitionRecordTemplate.copy("data", "a"); + Record partitionB = partitionRecordTemplate.copy("data", "b"); + StructLikeSet expectedAll = + StructLikeSet.create( + TypeUtil.selectNot( + posDeletesTable.schema(), ImmutableSet.of(MetadataColumns.FILE_PATH_COLUMN_ID)) + .asStruct()); + expectedAll.addAll(expected(tab, deletesA.first(), partitionA, specId1, null)); + expectedAll.addAll(expected(tab, deletesB.first(), partitionB, specId1, null)); + StructLikeSet actualAll = + actual(tableName, tab, "partition.data = 'a' OR partition.data = 'b'", NON_PATH_COLS); + assertThat(actualAll) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedAll); + dropTable(tableName); + } + + @TestTemplate + public void testWritePartitionEvolutionDisallowed() throws Exception { + // Create unpartitioned table + String tableName = "write_partition_evolution_write"; + Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned()); + + // Add files with unpartitioned spec + DataFile dataFileUnpartitioned = dataFile(tab); + tab.newAppend().appendFile(dataFileUnpartitioned).commit(); + Pair>, DeleteFile> deletesUnpartitioned = + deleteFile(tab, dataFileUnpartitioned); + tab.newRowDelta().addDeletes(deletesUnpartitioned.second()).commit(); + + Table posDeletesTable = + 
MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." + tableName + ".position_deletes"; + + Dataset scanDF; + String fileSetID = UUID.randomUUID().toString(); + try (CloseableIterable tasks = posDeletesTable.newBatchScan().planFiles()) { + stageTask(tab, fileSetID, tasks); + + scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + assertThat(scanDF.javaRDD().getNumPartitions()).isEqualTo(1); + + // Add partition field to render the original un-partitioned dataset un-commitable + tab.updateSpec().addField("data").commit(); + } + + assertThatThrownBy( + () -> + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append()) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining( + "[INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA] Cannot write incompatible data for the table `" + + catalogName + + "`.`default`.`" + + tableName + + "`.`position_deletes`" + + ": Cannot find data for the output column `partition`."); + + dropTable(tableName); + } + + @TestTemplate + public void testWriteSchemaEvolutionAdd() throws Exception { + // Create table with original schema + String tableName = "write_schema_evolution_add"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + + // Add files with original schema + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Add files with new schema + tab.updateSchema() + .addColumn("new_col_1", Types.IntegerType.get()) + .addColumn("new_col_2", Types.IntegerType.get()) + .commit(); + + // Add files with new schema + DataFile dataFileC = dataFile(tab, "c"); + DataFile dataFileD = dataFile(tab, "d"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesC = deleteFile(tab, dataFileC, "c"); + Pair>, DeleteFile> deletesD = deleteFile(tab, dataFileD, "d"); + tab.newRowDelta().addDeletes(deletesC.second()).addDeletes(deletesD.second()).commit(); + + Table posDeletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." 
+ tableName + ".position_deletes"; + + // rewrite files of old schema + try (CloseableIterable tasks = tasks(posDeletesTable, "data", "a")) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + + assertThat(scanDF.javaRDD().getNumPartitions()).isEqualTo(1); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + + // Select deletes from old schema + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + // pad expected delete rows with null values for new columns + List> expectedDeletesA = deletesA.first(); + expectedDeletesA.forEach( + d -> { + GenericRecord nested = d.get(2, GenericRecord.class); + GenericRecord padded = GenericRecord.create(tab.schema().asStruct()); + padded.set(0, nested.get(0)); + padded.set(1, nested.get(1)); + padded.set(2, null); + padded.set(3, null); + d.set(2, padded); + }); + StructLikeSet expectedA = expected(tab, expectedDeletesA, partitionA, null); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a'", NON_PATH_COLS); + assertThat(actualA) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedA); + + // rewrite files of new schema + try (CloseableIterable tasks = tasks(posDeletesTable, "data", "c")) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + + assertThat(scanDF.javaRDD().getNumPartitions()).isEqualTo(1); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + + // Select deletes from new schema + Record partitionC = partitionRecordTemplate.copy("data", "c"); + StructLikeSet expectedC = expected(tab, deletesC.first(), partitionC, null); + StructLikeSet actualC = actual(tableName, tab, "partition.data = 'c'", NON_PATH_COLS); + + assertThat(actualC) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedC); + dropTable(tableName); + } + + @TestTemplate + public void testWriteSchemaEvolutionRemove() throws Exception { + // Create table with original schema + String tableName = "write_schema_evolution_remove"; + Schema oldSchema = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.required(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "new_col_1", Types.IntegerType.get()), + Types.NestedField.optional(4, "new_col_2", Types.IntegerType.get())); + PartitionSpec spec = PartitionSpec.builderFor(oldSchema).identity("data").build(); + Table tab = createTable(tableName, oldSchema, spec); + + // Add files with original schema + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, 
dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Add files with new schema + tab.updateSchema().deleteColumn("new_col_1").deleteColumn("new_col_2").commit(); + + // Add files with new schema + DataFile dataFileC = dataFile(tab, "c"); + DataFile dataFileD = dataFile(tab, "d"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesC = deleteFile(tab, dataFileC, "c"); + Pair>, DeleteFile> deletesD = deleteFile(tab, dataFileD, "d"); + tab.newRowDelta().addDeletes(deletesC.second()).addDeletes(deletesD.second()).commit(); + + Table posDeletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." + tableName + ".position_deletes"; + + // rewrite files + for (String partValue : ImmutableList.of("a", "b", "c", "d")) { + try (CloseableIterable tasks = tasks(posDeletesTable, "data", partValue)) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + assertThat(scanDF.javaRDD().getNumPartitions()).isEqualTo(1); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + } + + // Select deletes from old schema + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + // remove deleted columns from expected result + List> expectedDeletesA = deletesA.first(); + expectedDeletesA.forEach( + d -> { + GenericRecord nested = d.get(2, GenericRecord.class); + GenericRecord padded = GenericRecord.create(tab.schema().asStruct()); + padded.set(0, nested.get(0)); + padded.set(1, nested.get(1)); + d.set(2, padded); + }); + StructLikeSet expectedA = expected(tab, expectedDeletesA, partitionA, null); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a'", NON_PATH_COLS); + assertThat(actualA) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedA); + + // Select deletes from new schema + Record partitionC = partitionRecordTemplate.copy("data", "c"); + StructLikeSet expectedC = expected(tab, deletesC.first(), partitionC, null); + StructLikeSet actualC = actual(tableName, tab, "partition.data = 'c'", NON_PATH_COLS); + + assertThat(actualC) + .as("Position Delete table should contain expected rows") + .isEqualTo(expectedC); + dropTable(tableName); + } + + @TestTemplate + public void testNormalWritesNotAllowed() throws IOException { + String tableName = "test_normal_write_not_allowed"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + + DataFile dataFileA = dataFile(tab, "a"); + tab.newAppend().appendFile(dataFileA).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + tab.newRowDelta().addDeletes(deletesA.second()).commit(); + + String posDeletesTableName = catalogName + ".default." 
+ tableName + ".position_deletes"; + + Dataset scanDF = spark.read().format("iceberg").load(posDeletesTableName); + + assertThatThrownBy(() -> scanDF.writeTo(posDeletesTableName).append()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Can only write to " + posDeletesTableName + " via actions"); + + dropTable(tableName); + } + + private StructLikeSet actual(String tableName, Table table) { + return actual(tableName, table, null, null); + } + + private StructLikeSet actual(String tableName, Table table, String filter) { + return actual(tableName, table, filter, null); + } + + private StructLikeSet actual(String tableName, Table table, String filter, List cols) { + Dataset df = + spark + .read() + .format("iceberg") + .load(catalogName + ".default." + tableName + ".position_deletes"); + if (filter != null) { + df = df.filter(filter); + } + if (cols != null) { + df = df.select(cols.get(0), cols.subList(1, cols.size()).toArray(new String[0])); + } + Table deletesTable = + MetadataTableUtils.createMetadataTableInstance(table, MetadataTableType.POSITION_DELETES); + Types.StructType projection = deletesTable.schema().asStruct(); + if (cols != null) { + projection = + Types.StructType.of( + projection.fields().stream() + .filter(f -> cols.contains(f.name())) + .collect(Collectors.toList())); + } + Types.StructType finalProjection = projection; + StructLikeSet set = StructLikeSet.create(projection); + df.collectAsList() + .forEach( + row -> { + SparkStructLike rowWrapper = new SparkStructLike(finalProjection); + set.add(rowWrapper.wrap(row)); + }); + + return set; + } + + protected Table createTable(String name, Schema schema, PartitionSpec spec) { + Map properties = + ImmutableMap.of( + TableProperties.FORMAT_VERSION, + "2", + TableProperties.DEFAULT_FILE_FORMAT, + format.toString()); + return validationCatalog.createTable( + TableIdentifier.of("default", name), schema, spec, properties); + } + + protected void dropTable(String name) { + validationCatalog.dropTable(TableIdentifier.of("default", name), false); + } + + private PositionDelete positionDelete(CharSequence path, Long position) { + PositionDelete posDelete = PositionDelete.create(); + posDelete.set(path, position, null); + return posDelete; + } + + private PositionDelete positionDelete( + Schema tableSchema, CharSequence path, Long position, Object... 
values) { + PositionDelete posDelete = PositionDelete.create(); + GenericRecord nested = GenericRecord.create(tableSchema); + for (int i = 0; i < values.length; i++) { + nested.set(i, values[i]); + } + posDelete.set(path, position, nested); + return posDelete; + } + + private StructLikeSet expected( + Table testTable, + List> deletes, + StructLike partitionStruct, + int specId, + String deleteFilePath) { + Table deletesTable = + MetadataTableUtils.createMetadataTableInstance( + testTable, MetadataTableType.POSITION_DELETES); + Types.StructType posDeleteSchema = deletesTable.schema().asStruct(); + // Do not compare file paths + if (deleteFilePath == null) { + posDeleteSchema = + TypeUtil.selectNot( + deletesTable.schema(), ImmutableSet.of(MetadataColumns.FILE_PATH_COLUMN_ID)) + .asStruct(); + } + final Types.StructType finalSchema = posDeleteSchema; + StructLikeSet set = StructLikeSet.create(posDeleteSchema); + deletes.stream() + .map( + p -> { + GenericRecord record = GenericRecord.create(finalSchema); + record.setField("file_path", p.path()); + record.setField("pos", p.pos()); + record.setField("row", p.row()); + if (partitionStruct != null) { + record.setField("partition", partitionStruct); + } + record.setField("spec_id", specId); + if (deleteFilePath != null) { + record.setField("delete_file_path", deleteFilePath); + } + return record; + }) + .forEach(set::add); + return set; + } + + private StructLikeSet expected( + Table testTable, + List> deletes, + StructLike partitionStruct, + String deleteFilePath) { + return expected(testTable, deletes, partitionStruct, testTable.spec().specId(), deleteFilePath); + } + + private DataFile dataFile(Table tab, Object... partValues) throws IOException { + return dataFile(tab, partValues, partValues); + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + private DataFile dataFile(Table tab, Object[] partDataValues, Object[] partFieldValues) + throws IOException { + GenericRecord record = GenericRecord.create(tab.schema()); + List partitionFieldNames = + tab.spec().fields().stream().map(PartitionField::name).collect(Collectors.toList()); + int idIndex = partitionFieldNames.indexOf("id"); + int dataIndex = partitionFieldNames.indexOf("data"); + Integer idPartition = idIndex != -1 ? (Integer) partDataValues[idIndex] : null; + String dataPartition = dataIndex != -1 ? (String) partDataValues[dataIndex] : null; + + // fill columns with partition source fields, or preset values + List records = + Lists.newArrayList( + record.copy( + "id", + idPartition != null ? idPartition : 29, + "data", + dataPartition != null ? dataPartition : "c"), + record.copy( + "id", + idPartition != null ? idPartition : 43, + "data", + dataPartition != null ? dataPartition : "k"), + record.copy( + "id", + idPartition != null ? idPartition : 61, + "data", + dataPartition != null ? dataPartition : "r"), + record.copy( + "id", + idPartition != null ? idPartition : 89, + "data", + dataPartition != null ? dataPartition : "t")); + + // fill remaining columns with incremental values + List cols = tab.schema().columns(); + if (cols.size() > 2) { + for (int i = 2; i < cols.size(); i++) { + final int pos = i; + records.forEach(r -> r.set(pos, pos)); + } + } + + TestHelpers.Row partitionInfo = TestHelpers.Row.of(partFieldValues); + return FileHelpers.writeDataFile( + tab, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + partitionInfo, + records); + } + + private Pair>, DeleteFile> deleteFile( + Table tab, DataFile dataFile, Object... 
partValues) throws IOException { + return deleteFile(tab, dataFile, partValues, partValues); + } + + private Pair>, DeleteFile> deleteFile( + Table tab, DataFile dataFile, Object[] partDataValues, Object[] partFieldValues) + throws IOException { + List partFields = tab.spec().fields(); + List partitionFieldNames = + partFields.stream().map(PartitionField::name).collect(Collectors.toList()); + int idIndex = partitionFieldNames.indexOf("id"); + int dataIndex = partitionFieldNames.indexOf("data"); + Integer idPartition = idIndex != -1 ? (Integer) partDataValues[idIndex] : null; + String dataPartition = dataIndex != -1 ? (String) partDataValues[dataIndex] : null; + + // fill columns with partition source fields, or preset values + List> deletes = + Lists.newArrayList( + positionDelete( + tab.schema(), + dataFile.location(), + 0L, + idPartition != null ? idPartition : 29, + dataPartition != null ? dataPartition : "c"), + positionDelete( + tab.schema(), + dataFile.location(), + 1L, + idPartition != null ? idPartition : 61, + dataPartition != null ? dataPartition : "r")); + + // fill remaining columns with incremental values + List cols = tab.schema().columns(); + if (cols.size() > 2) { + for (int i = 2; i < cols.size(); i++) { + final int pos = i; + deletes.forEach(d -> d.get(2, GenericRecord.class).set(pos, pos)); + } + } + + TestHelpers.Row partitionInfo = TestHelpers.Row.of(partFieldValues); + + DeleteFile deleteFile = + FileHelpers.writePosDeleteFile( + tab, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + partitionInfo, + deletes); + return Pair.of(deletes, deleteFile); + } + + private void stageTask( + Table tab, String fileSetID, CloseableIterable tasks) { + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + taskSetManager.stageTasks(tab, fileSetID, Lists.newArrayList(tasks)); + } + + private void commit( + Table baseTab, + Table posDeletesTable, + String fileSetID, + int expectedSourceFiles, + int expectedTargetFiles) { + PositionDeletesRewriteCoordinator rewriteCoordinator = PositionDeletesRewriteCoordinator.get(); + Set rewrittenFiles = + ScanTaskSetManager.get().fetchTasks(posDeletesTable, fileSetID).stream() + .map(t -> ((PositionDeletesScanTask) t).file()) + .collect(Collectors.toCollection(DeleteFileSet::create)); + Set addedFiles = rewriteCoordinator.fetchNewFiles(posDeletesTable, fileSetID); + + // Assert new files and old files are equal in number but different in paths + assertThat(rewrittenFiles).hasSize(expectedSourceFiles); + assertThat(addedFiles).hasSize(expectedTargetFiles); + + List sortedAddedFiles = + addedFiles.stream().map(f -> f.location()).sorted().collect(Collectors.toList()); + List sortedRewrittenFiles = + rewrittenFiles.stream().map(f -> f.location()).sorted().collect(Collectors.toList()); + assertThat(sortedRewrittenFiles) + .as("Lists should not be the same") + .isNotEqualTo(sortedAddedFiles); + + baseTab + .newRewrite() + .rewriteFiles(ImmutableSet.of(), rewrittenFiles, ImmutableSet.of(), addedFiles) + .commit(); + } + + private void commit(Table baseTab, Table posDeletesTable, String fileSetID, int expectedFiles) { + commit(baseTab, posDeletesTable, fileSetID, expectedFiles, expectedFiles); + } + + private CloseableIterable tasks( + Table posDeletesTable, String partitionColumn, String partitionValue) { + + Expression filter = Expressions.equal("partition." 
+ partitionColumn, partitionValue); + CloseableIterable files = posDeletesTable.newBatchScan().filter(filter).planFiles(); + + // take care of fail to filter in some partition evolution cases + return CloseableIterable.filter( + files, + t -> { + StructLike filePartition = ((PositionDeletesScanTask) t).partition(); + String filePartitionValue = filePartition.get(0, String.class); + return filePartitionValue != null && filePartitionValue.equals(partitionValue); + }); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestReadProjection.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestReadProjection.java new file mode 100644 index 000000000000..5f59c8eef4ba --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestReadProjection.java @@ -0,0 +1,638 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.avro.Schema.Type.UNION; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.within; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.Comparators; +import org.apache.iceberg.types.Types; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public abstract class TestReadProjection { + @Parameter(index = 0) + protected FileFormat format; + + protected abstract Record writeAndRead( + String desc, Schema writeSchema, Schema readSchema, Record record) throws IOException; + + @TempDir protected Path temp; + + @TestTemplate + public void testFullProjection() throws Exception { + Schema schema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get())); + + Record record = GenericRecord.create(schema); + record.setField("id", 34L); + record.setField("data", "test"); 
+
+    Record projected = writeAndRead("full_projection", schema, schema, record);
+
+    assertThat((long) projected.getField("id"))
+        .as("Should contain the correct id value")
+        .isEqualTo(34L);
+
+    int cmp =
+        Comparators.charSequences().compare("test", (CharSequence) projected.getField("data"));
+
+    assertThat(cmp).as("Should contain the correct data value").isEqualTo(0);
+  }
+
+  @TestTemplate
+  public void testReorderedFullProjection() throws Exception {
+    // Assume.assumeTrue(
+    //     "Spark's Parquet read support does not support reordered columns",
+    //     !format.equalsIgnoreCase("parquet"));
+
+    Schema schema =
+        new Schema(
+            Types.NestedField.required(0, "id", Types.LongType.get()),
+            Types.NestedField.optional(1, "data", Types.StringType.get()));
+
+    Record record = GenericRecord.create(schema);
+    record.setField("id", 34L);
+    record.setField("data", "test");
+
+    Schema reordered =
+        new Schema(
+            Types.NestedField.optional(1, "data", Types.StringType.get()),
+            Types.NestedField.required(0, "id", Types.LongType.get()));
+
+    Record projected = writeAndRead("reordered_full_projection", schema, reordered, record);
+
+    assertThat(projected.get(0).toString())
+        .as("Should contain the correct 0 value")
+        .isEqualTo("test");
+    assertThat(projected.get(1)).as("Should contain the correct 1 value").isEqualTo(34L);
+  }
+
+  @TestTemplate
+  public void testReorderedProjection() throws Exception {
+    // Assume.assumeTrue(
+    //     "Spark's Parquet read support does not support reordered columns",
+    //     !format.equalsIgnoreCase("parquet"));
+
+    Schema schema =
+        new Schema(
+            Types.NestedField.required(0, "id", Types.LongType.get()),
+            Types.NestedField.optional(1, "data", Types.StringType.get()));
+
+    Record record = GenericRecord.create(schema);
+    record.setField("id", 34L);
+    record.setField("data", "test");
+
+    Schema reordered =
+        new Schema(
+            Types.NestedField.optional(2, "missing_1", Types.StringType.get()),
+            Types.NestedField.optional(1, "data", Types.StringType.get()),
+            Types.NestedField.optional(3, "missing_2", Types.LongType.get()));
+
+    Record projected = writeAndRead("reordered_projection", schema, reordered, record);
+
+    assertThat(projected.get(0)).as("Should contain the correct 0 value").isNull();
+    assertThat(projected.get(1).toString())
+        .as("Should contain the correct 1 value")
+        .isEqualTo("test");
+    assertThat(projected.get(2)).as("Should contain the correct 2 value").isNull();
+  }
+
+  @TestTemplate
+  public void testEmptyProjection() throws Exception {
+    Schema schema =
+        new Schema(
+            Types.NestedField.required(0, "id", Types.LongType.get()),
+            Types.NestedField.optional(1, "data", Types.StringType.get()));
+
+    Record record = GenericRecord.create(schema);
+    record.setField("id", 34L);
+    record.setField("data", "test");
+
+    Record projected = writeAndRead("empty_projection", schema, schema.select(), record);
+
+    assertThat(projected).as("Should read a non-null record").isNotNull();
+    // this is expected because there are no values
+    assertThatThrownBy(() -> projected.get(0)).isInstanceOf(ArrayIndexOutOfBoundsException.class);
+  }
+
+  @TestTemplate
+  public void testBasicProjection() throws Exception {
+    Schema writeSchema =
+        new Schema(
+            Types.NestedField.required(0, "id", Types.LongType.get()),
+            Types.NestedField.optional(1, "data", Types.StringType.get()));
+
+    Record record = GenericRecord.create(writeSchema);
+    record.setField("id", 34L);
+    record.setField("data", "test");
+
+    Schema idOnly = new Schema(Types.NestedField.required(0, "id", Types.LongType.get()));
+
+    Record projected = writeAndRead("basic_projection_id", writeSchema, idOnly, record);
+    assertThat(projected.getField("data")).as("Should not project data").isNull();
+    assertThat((long) projected.getField("id"))
+        .as("Should contain the correct id value")
+        .isEqualTo(34L);
+
+    Schema dataOnly = new Schema(Types.NestedField.optional(1, "data", Types.StringType.get()));
+
+    projected = writeAndRead("basic_projection_data", writeSchema, dataOnly, record);
+    assertThat(projected.getField("id")).as("Should not project id").isNull();
+
+    int cmp =
+        Comparators.charSequences().compare("test", (CharSequence) projected.getField("data"));
+    assertThat(cmp).as("Should contain the correct data value").isEqualTo(0);
+  }
+
+  @TestTemplate
+  public void testRename() throws Exception {
+    Schema writeSchema =
+        new Schema(
+            Types.NestedField.required(0, "id", Types.LongType.get()),
+            Types.NestedField.optional(1, "data", Types.StringType.get()));
+
+    Record record = GenericRecord.create(writeSchema);
+    record.setField("id", 34L);
+    record.setField("data", "test");
+
+    Schema readSchema =
+        new Schema(
+            Types.NestedField.required(0, "id", Types.LongType.get()),
+            Types.NestedField.optional(1, "renamed", Types.StringType.get()));
+
+    Record projected = writeAndRead("project_and_rename", writeSchema, readSchema, record);
+    assertThat((long) projected.getField("id"))
+        .as("Should contain the correct id value")
+        .isEqualTo(34L);
+
+    int cmp =
+        Comparators.charSequences().compare("test", (CharSequence) projected.getField("renamed"));
+    assertThat(cmp).as("Should contain the correct data/renamed value").isEqualTo(0);
+  }
+
+  @TestTemplate
+  public void testNestedStructProjection() throws Exception {
+    Schema writeSchema =
+        new Schema(
+            Types.NestedField.required(0, "id", Types.LongType.get()),
+            Types.NestedField.optional(
+                3,
+                "location",
+                Types.StructType.of(
+                    Types.NestedField.required(1, "lat", Types.FloatType.get()),
+                    Types.NestedField.required(2, "long", Types.FloatType.get()))));
+
+    Record record = GenericRecord.create(writeSchema);
+    record.setField("id", 34L);
+    Record location = GenericRecord.create(writeSchema.findType("location").asStructType());
+    location.setField("lat", 52.995143f);
+    location.setField("long", -1.539054f);
+    record.setField("location", location);
+
+    Schema idOnly = new Schema(Types.NestedField.required(0, "id", Types.LongType.get()));
+
+    Record projected = writeAndRead("id_only", writeSchema, idOnly, record);
+    Record projectedLocation = (Record) projected.getField("location");
+    assertThat((long) projected.getField("id"))
+        .as("Should contain the correct id value")
+        .isEqualTo(34L);
+    assertThat(projectedLocation).as("Should not project location").isNull();
+
+    Schema latOnly =
+        new Schema(
+            Types.NestedField.optional(
+                3,
+                "location",
+                Types.StructType.of(Types.NestedField.required(1, "lat", Types.FloatType.get()))));
+
+    projected = writeAndRead("latitude_only", writeSchema, latOnly, record);
+    projectedLocation = (Record) projected.getField("location");
+    assertThat(projected.getField("id")).as("Should not project id").isNull();
+    assertThat(projected.getField("location")).as("Should project location").isNotNull();
+    assertThat(projectedLocation.getField("long")).as("Should not project longitude").isNull();
+    assertThat((float) projectedLocation.getField("lat"))
+        .as("Should project latitude")
+        .isCloseTo(52.995143f, within(0.000001f));
+
+    Schema longOnly =
+        new Schema(
+            Types.NestedField.optional(
+                3,
+                "location",
+                Types.StructType.of(Types.NestedField.required(2, "long", Types.FloatType.get()))));
+
+    projected = writeAndRead("longitude_only", writeSchema, longOnly, record);
+    projectedLocation = (Record) projected.getField("location");
+
+    assertThat(projected.getField("id")).as("Should not project id").isNull();
+    assertThat(projected.getField("location")).as("Should project location").isNotNull();
+    assertThat(projectedLocation.getField("lat")).as("Should not project latitude").isNull();
+    assertThat((float) projectedLocation.getField("long"))
+        .as("Should project longitude")
+        .isCloseTo(-1.539054f, within(0.000001f));
+
+    Schema locationOnly = writeSchema.select("location");
+    projected = writeAndRead("location_only", writeSchema, locationOnly, record);
+    projectedLocation = (Record) projected.getField("location");
+
+    assertThat(projected.getField("id")).as("Should not project id").isNull();
+    assertThat(projected.getField("location")).as("Should project location").isNotNull();
+    assertThat((float) projectedLocation.getField("lat"))
+        .as("Should project latitude")
+        .isCloseTo(52.995143f, within(0.000001f));
+    assertThat((float) projectedLocation.getField("long"))
+        .as("Should project longitude")
+        .isCloseTo(-1.539054f, within(0.000001f));
+  }
+
+  @TestTemplate
+  public void testMapProjection() throws IOException {
+    Schema writeSchema =
+        new Schema(
+            Types.NestedField.required(0, "id", Types.LongType.get()),
+            Types.NestedField.optional(
+                5,
+                "properties",
+                Types.MapType.ofOptional(6, 7, Types.StringType.get(), Types.StringType.get())));
+
+    Map<String, String> properties = ImmutableMap.of("a", "A", "b", "B");
+
+    Record record = GenericRecord.create(writeSchema);
+    record.setField("id", 34L);
+    record.setField("properties", properties);
+
+    Schema idOnly = new Schema(Types.NestedField.required(0, "id", Types.LongType.get()));
+
+    Record projected = writeAndRead("id_only", writeSchema, idOnly, record);
+    assertThat((long) projected.getField("id"))
+        .as("Should contain the correct id value")
+        .isEqualTo(34L);
+    assertThat(projected.getField("properties")).as("Should not project properties map").isNull();
+
+    Schema keyOnly = writeSchema.select("properties.key");
+    projected = writeAndRead("key_only", writeSchema, keyOnly, record);
+    assertThat(projected.getField("id")).as("Should not project id").isNull();
+    assertThat(toStringMap((Map) projected.getField("properties")))
+        .as("Should project entire map")
+        .isEqualTo(properties);
+
+    Schema valueOnly = writeSchema.select("properties.value");
+    projected = writeAndRead("value_only", writeSchema, valueOnly, record);
+    assertThat(projected.getField("id")).as("Should not project id").isNull();
+    assertThat(toStringMap((Map) projected.getField("properties")))
+        .as("Should project entire map")
+        .isEqualTo(properties);
+
+    Schema mapOnly = writeSchema.select("properties");
+    projected = writeAndRead("map_only", writeSchema, mapOnly, record);
+    assertThat(projected.getField("id")).as("Should not project id").isNull();
+    assertThat(toStringMap((Map) projected.getField("properties")))
+        .as("Should project entire map")
+        .isEqualTo(properties);
+  }
+
+  private Map<String, ?> toStringMap(Map<?, ?> map) {
+    Map<String, Object> stringMap = Maps.newHashMap();
+    for (Map.Entry<?, ?> entry : map.entrySet()) {
+      if (entry.getValue() instanceof CharSequence) {
+        stringMap.put(entry.getKey().toString(), entry.getValue().toString());
+      } else {
+        stringMap.put(entry.getKey().toString(), entry.getValue());
+      }
+    }
+    return stringMap;
+  }
+
+  @TestTemplate
+  public void testMapOfStructsProjection() throws
IOException { + Schema writeSchema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 5, + "locations", + Types.MapType.ofOptional( + 6, + 7, + Types.StringType.get(), + Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get()))))); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + Record l1 = GenericRecord.create(writeSchema.findType("locations.value").asStructType()); + l1.setField("lat", 53.992811f); + l1.setField("long", -1.542616f); + Record l2 = GenericRecord.create(l1.struct()); + l2.setField("lat", 52.995143f); + l2.setField("long", -1.539054f); + record.setField("locations", ImmutableMap.of("L1", l1, "L2", l2)); + + Schema idOnly = new Schema(Types.NestedField.required(0, "id", Types.LongType.get())); + + Record projected = writeAndRead("id_only", writeSchema, idOnly, record); + assertThat(34L) + .as("Should contain the correct id value") + .isEqualTo((long) projected.getField("id")); + assertThat(projected.getField("locations")).as("Should not project locations map").isNull(); + + projected = writeAndRead("all_locations", writeSchema, writeSchema.select("locations"), record); + assertThat(projected.getField("id")).as("Should not project id").isNull(); + assertThat(toStringMap((Map) projected.getField("locations"))) + .as("Should project locations map") + .isEqualTo(record.getField("locations")); + + projected = writeAndRead("lat_only", writeSchema, writeSchema.select("locations.lat"), record); + assertThat(projected.getField("id")).as("Should not project id").isNull(); + + Map locations = toStringMap((Map) projected.getField("locations")); + assertThat(locations).as("Should project locations map").isNotNull(); + assertThat(locations.keySet()) + .as("Should contain L1 and L2") + .isEqualTo(Sets.newHashSet("L1", "L2")); + + Record projectedL1 = (Record) locations.get("L1"); + assertThat(projectedL1).as("L1 should not be null").isNotNull(); + assertThat((float) projectedL1.getField("lat")) + .as("L1 should contain lat") + .isCloseTo(53.992811f, within(0.000001f)); + assertThat(projectedL1.getField("long")).as("L1 should not contain long").isNull(); + + Record projectedL2 = (Record) locations.get("L2"); + assertThat(projectedL2).as("L2 should not be null").isNotNull(); + assertThat((float) projectedL2.getField("lat")) + .as("L2 should contain lat") + .isCloseTo(52.995143f, within(0.000001f)); + assertThat(projectedL2.getField("long")).as("L2 should not contain long").isNull(); + + projected = + writeAndRead("long_only", writeSchema, writeSchema.select("locations.long"), record); + assertThat(projected.getField("id")).as("Should not project id").isNull(); + + locations = toStringMap((Map) projected.getField("locations")); + assertThat(locations).as("Should project locations map").isNotNull(); + assertThat(locations.keySet()) + .as("Should contain L1 and L2") + .isEqualTo(Sets.newHashSet("L1", "L2")); + + projectedL1 = (Record) locations.get("L1"); + assertThat(projectedL1).as("L1 should not be null").isNotNull(); + assertThat(projectedL1.getField("lat")).as("L1 should not contain lat").isNull(); + assertThat((float) projectedL1.getField("long")) + .as("L1 should contain long") + .isCloseTo(-1.542616f, within(0.000001f)); + + projectedL2 = (Record) locations.get("L2"); + assertThat(projectedL2).as("L2 should not be null").isNotNull(); + assertThat(projectedL2.getField("lat")).as("L2 should not contain 
lat").isNull(); + assertThat((float) projectedL2.getField("long")) + .as("L2 should contain long") + .isCloseTo(-1.539054f, within(0.000001f)); + + Schema latitiudeRenamed = + new Schema( + Types.NestedField.optional( + 5, + "locations", + Types.MapType.ofOptional( + 6, + 7, + Types.StringType.get(), + Types.StructType.of( + Types.NestedField.required(1, "latitude", Types.FloatType.get()))))); + + projected = writeAndRead("latitude_renamed", writeSchema, latitiudeRenamed, record); + assertThat(projected.getField("id")).as("Should not project id").isNull(); + locations = toStringMap((Map) projected.getField("locations")); + assertThat(locations).as("Should project locations map").isNotNull(); + assertThat(locations.keySet()) + .as("Should contain L1 and L2") + .isEqualTo(Sets.newHashSet("L1", "L2")); + + projectedL1 = (Record) locations.get("L1"); + assertThat(projectedL1).as("L1 should not be null").isNotNull(); + assertThat((float) projectedL1.getField("latitude")) + .as("L1 should contain latitude") + .isCloseTo(53.992811f, within(0.000001f)); + assertThat(projectedL1.getField("lat")).as("L1 should not contain lat").isNull(); + assertThat(projectedL1.getField("long")).as("L1 should not contain long").isNull(); + + projectedL2 = (Record) locations.get("L2"); + assertThat(projectedL2).as("L2 should not be null").isNotNull(); + assertThat((float) projectedL2.getField("latitude")) + .as("L2 should contain latitude") + .isCloseTo(52.995143f, within(0.000001f)); + assertThat(projectedL2.getField("lat")).as("L2 should not contain lat").isNull(); + assertThat(projectedL2.getField("long")).as("L2 should not contain long").isNull(); + } + + @TestTemplate + public void testListProjection() throws IOException { + Schema writeSchema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 10, "values", Types.ListType.ofOptional(11, Types.LongType.get()))); + + List values = ImmutableList.of(56L, 57L, 58L); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + record.setField("values", values); + + Schema idOnly = new Schema(Types.NestedField.required(0, "id", Types.LongType.get())); + + Record projected = writeAndRead("id_only", writeSchema, idOnly, record); + assertThat((long) projected.getField("id")) + .as("Should contain the correct id value") + .isEqualTo(34L); + assertThat(projected.getField("values")).as("Should not project values list").isNull(); + + Schema elementOnly = writeSchema.select("values.element"); + projected = writeAndRead("element_only", writeSchema, elementOnly, record); + assertThat(projected.getField("id")).as("Should not project id").isNull(); + assertThat(projected.getField("values")).as("Should project entire list").isEqualTo(values); + + Schema listOnly = writeSchema.select("values"); + projected = writeAndRead("list_only", writeSchema, listOnly, record); + assertThat(projected.getField("id")).as("Should not project id").isNull(); + assertThat(projected.getField("values")).as("Should project entire list").isEqualTo(values); + } + + @TestTemplate + @SuppressWarnings("unchecked") + public void testListOfStructsProjection() throws IOException { + Schema writeSchema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 22, + "points", + Types.ListType.ofOptional( + 21, + Types.StructType.of( + Types.NestedField.required(19, "x", Types.IntegerType.get()), + Types.NestedField.optional(18, "y", Types.IntegerType.get()))))); + + Record record 
= GenericRecord.create(writeSchema); + record.setField("id", 34L); + Record p1 = GenericRecord.create(writeSchema.findType("points.element").asStructType()); + p1.setField("x", 1); + p1.setField("y", 2); + Record p2 = GenericRecord.create(p1.struct()); + p2.setField("x", 3); + p2.setField("y", null); + record.setField("points", ImmutableList.of(p1, p2)); + + Schema idOnly = new Schema(Types.NestedField.required(0, "id", Types.LongType.get())); + + Record projected = writeAndRead("id_only", writeSchema, idOnly, record); + assertThat((long) projected.getField("id")) + .as("Should contain the correct id value") + .isEqualTo(34L); + assertThat(projected.getField("points")).as("Should not project points list").isNull(); + + projected = writeAndRead("all_points", writeSchema, writeSchema.select("points"), record); + assertThat(projected.getField("id")).as("Should not project id").isNull(); + assertThat(projected.getField("points")) + .as("Should project points list") + .isEqualTo(record.getField("points")); + + projected = writeAndRead("x_only", writeSchema, writeSchema.select("points.x"), record); + assertThat(projected.getField("id")).as("Should not project id").isNull(); + assertThat(projected.getField("points")).as("Should project points list").isNotNull(); + + List points = (List) projected.getField("points"); + assertThat(points).as("Should read 2 points").hasSize(2); + + Record projectedP1 = points.get(0); + assertThat((int) projectedP1.getField("x")).as("Should project x").isEqualTo(1); + assertThat(projected.getField("y")).as("Should not project y").isNull(); + + Record projectedP2 = points.get(1); + assertThat((int) projectedP2.getField("x")).as("Should project x").isEqualTo(3); + assertThat(projected.getField("y")).as("Should not project y").isNull(); + + projected = writeAndRead("y_only", writeSchema, writeSchema.select("points.y"), record); + assertThat(projected.getField("id")).as("Should not project id").isNull(); + assertThat(projected.getField("points")).as("Should project points list").isNotNull(); + + points = (List) projected.getField("points"); + assertThat(points).as("Should read 2 points").hasSize(2); + + projectedP1 = points.get(0); + assertThat(projectedP1.getField("x")).as("Should not project x").isNull(); + assertThat((int) projectedP1.getField("y")).as("Should project y").isEqualTo(2); + + projectedP2 = points.get(1); + assertThat(projectedP2.getField("x")).as("Should not project x").isNull(); + assertThat(projectedP2.getField("y")).as("Should not project y").isNull(); + + Schema yRenamed = + new Schema( + Types.NestedField.optional( + 22, + "points", + Types.ListType.ofOptional( + 21, + Types.StructType.of( + Types.NestedField.optional(18, "z", Types.IntegerType.get()))))); + + projected = writeAndRead("y_renamed", writeSchema, yRenamed, record); + assertThat(projected.getField("id")).as("Should not project id").isNull(); + assertThat(projected.getField("points")).as("Should project points list").isNotNull(); + + points = (List) projected.getField("points"); + assertThat(points).as("Should read 2 points").hasSize(2); + + projectedP1 = points.get(0); + assertThat(projectedP1.getField("x")).as("Should not project x").isNull(); + assertThat(projectedP1.getField("y")).as("Should not project y").isNull(); + assertThat((int) projectedP1.getField("z")).as("Should project z").isEqualTo(2); + + projectedP2 = points.get(1); + assertThat(projectedP2.getField("x")).as("Should not project x").isNull(); + assertThat(projectedP2.getField("y")).as("Should not project y").isNull(); + 
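+      // Note: every case in this class funnels through the abstract
+      // writeAndRead(desc, writeSchema, readSchema, record) hook, which the format-specific
+      // subclasses implement: write `record` with writeSchema, read it back with the projected
+      // readSchema, return the single resulting Record. A hedged sketch of that contract, assuming
+      // an Iceberg generic Avro round trip through a temp file (the Spark subclasses route the
+      // read through Spark's readers instead):
+      //
+      //   protected Record writeAndRead(String desc, Schema writeSchema, Schema readSchema, Record rec)
+      //       throws IOException {
+      //     File file = new File(temp.toFile(), desc + ".avro"); // hypothetical temp location
+      //     try (FileAppender<Record> writer =
+      //         Avro.write(Files.localOutput(file)).schema(writeSchema).createWriterFunc(DataWriter::create).build()) {
+      //       writer.add(rec);
+      //     }
+      //     try (CloseableIterable<Record> reader =
+      //         Avro.read(Files.localInput(file)).project(readSchema).createReaderFunc(DataReader::create).build()) {
+      //       return Iterables.getOnlyElement(reader);
+      //     }
+      //   }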
assertThat(projectedP2.getField("z")).as("Should project null z").isNull(); + + Schema zAdded = + new Schema( + Types.NestedField.optional( + 22, + "points", + Types.ListType.ofOptional( + 21, + Types.StructType.of( + Types.NestedField.required(19, "x", Types.IntegerType.get()), + Types.NestedField.optional(18, "y", Types.IntegerType.get()), + Types.NestedField.optional(20, "z", Types.IntegerType.get()))))); + + projected = writeAndRead("z_added", writeSchema, zAdded, record); + assertThat(projected.getField("id")).as("Should not project id").isNull(); + assertThat(projected.getField("points")).as("Should project points list").isNotNull(); + + points = (List) projected.getField("points"); + assertThat(points).as("Should read 2 points").hasSize(2); + + projectedP1 = points.get(0); + assertThat((int) projectedP1.getField("x")).as("Should project x").isEqualTo(1); + assertThat((int) projectedP1.getField("y")).as("Should project y").isEqualTo(2); + assertThat(projectedP1.getField("z")).as("Should contain null z").isNull(); + + projectedP2 = points.get(1); + assertThat((int) projectedP2.getField("x")).as("Should project x").isEqualTo(3); + assertThat(projectedP2.getField("y")).as("Should project null y").isNull(); + assertThat(projectedP2.getField("z")).as("Should contain null z").isNull(); + } + + private static org.apache.avro.Schema fromOption(org.apache.avro.Schema schema) { + Preconditions.checkArgument( + schema.getType() == UNION, "Expected union schema but was passed: %s", schema); + Preconditions.checkArgument( + schema.getTypes().size() == 2, "Expected optional schema, but was passed: %s", schema); + if (schema.getTypes().get(0).getType() == org.apache.avro.Schema.Type.NULL) { + return schema.getTypes().get(1); + } else { + return schema.getTypes().get(0); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestRequiredDistributionAndOrdering.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestRequiredDistributionAndOrdering.java new file mode 100644 index 000000000000..55fd2cefe2e6 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestRequiredDistributionAndOrdering.java @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestRequiredDistributionAndOrdering extends CatalogTestBase { + + @AfterEach + public void dropTestTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testDefaultLocalSort() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should insert a local sort by partition columns by default + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @TestTemplate + public void testPartitionColumnsArePrependedForRangeDistribution() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + // should automatically prepend partition columns to the ordering + table + .updateProperties() + .set(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + table.replaceSortOrder().asc("c1").asc("c2").commit(); + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @TestTemplate + public void testSortOrderIncludesPartitionColumns() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = 
spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + // should succeed with a correct sort order + table.replaceSortOrder().asc("c3").asc("c1").asc("c2").commit(); + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @TestTemplate + public void testDisabledDistributionAndOrdering() { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should fail if ordering is disabled + assertThatThrownBy( + () -> + inputDF + .writeTo(tableName) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .option(SparkWriteOptions.FANOUT_ENABLED, "false") + .append()) + .cause() + .isInstanceOf(IllegalStateException.class) + .hasMessageStartingWith( + "Incoming records violate the writer assumption that records are clustered by spec " + + "and by partition within each spec. Either cluster the incoming records or switch to fanout writers."); + } + + @TestTemplate + public void testHashDistribution() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + // should automatically prepend partition columns to the local ordering after hash distribution + table + .updateProperties() + .set(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + table.replaceSortOrder().asc("c1").asc("c2").commit(); + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @TestTemplate + public void testSortBucketTransformsWithoutExtensions() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBB", "B"), + new ThreeColumnRecord(3, "BBBB", "B"), + new ThreeColumnRecord(4, "BBBB", "B")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + inputDF.writeTo(tableName).append(); + + List expected = + ImmutableList.of( + row(1, null, "A"), row(2, 
"BBBB", "B"), row(3, "BBBB", "B"), row(4, "BBBB", "B")); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @TestTemplate + public void testRangeDistributionWithQuotedColumnsNames() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, `c.3` STRING) " + + "USING iceberg " + + "PARTITIONED BY (`c.3`)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = + ds.selectExpr("c1", "c2", "c3 as `c.3`").coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + table + .updateProperties() + .set(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + table.replaceSortOrder().asc("c1").asc("c2").commit(); + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @TestTemplate + public void testHashDistributionWithQuotedColumnsNames() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, `c``3` STRING) " + + "USING iceberg " + + "PARTITIONED BY (`c``3`)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = + ds.selectExpr("c1", "c2", "c3 as `c``3`").coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + table + .updateProperties() + .set(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + table.replaceSortOrder().asc("c1").asc("c2").commit(); + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestRuntimeFiltering.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestRuntimeFiltering.java new file mode 100644 index 000000000000..e7346e270f38 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestRuntimeFiltering.java @@ -0,0 +1,510 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; +import java.util.Set; +import org.apache.commons.lang3.StringUtils; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PlanningMode; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestRuntimeFiltering extends TestBaseWithCatalog { + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}, planningMode = {3}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + LOCAL + }, + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + DISTRIBUTED + } + }; + } + + @Parameter(index = 3) + private PlanningMode planningMode; + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS dim"); + } + + @TestTemplate + public void testIdentityPartitionedTable() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + configurePlanningMode(planningMode); + + Dataset df = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 10).withColumn("date", expr("DATE '1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + 
String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.date = d.date AND d.id = 1 ORDER BY id", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("date", 1), 3); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE date = DATE '1970-01-02' ORDER BY id", tableName), + sql(query)); + } + + @TestTemplate + public void testBucketedTable() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (bucket(8, id))", + tableName); + configurePlanningMode(planningMode); + + Dataset df = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 2).withColumn("date", expr("DATE '1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.id = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("id", 1), 7); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 ORDER BY date", tableName), + sql(query)); + } + + @TestTemplate + public void testRenamedSourceColumnTable() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (bucket(8, id))", + tableName); + configurePlanningMode(planningMode); + + Dataset df = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 2).withColumn("date", expr("DATE '1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + sql("ALTER TABLE %s RENAME COLUMN id TO row_id", tableName); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.row_id = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("row_id", 1), 7); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE row_id = 1 ORDER BY date", tableName), + sql(query)); + } + + @TestTemplate + public void testMultipleRuntimeFilters() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (data, bucket(8, id))", + tableName); + configurePlanningMode(planningMode); + + Dataset df = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + 
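+    // The assertion helpers at the bottom of this class prove runtime filtering by counting
+    // occurrences of "dynamicpruningexpression" in the EXPLAIN EXTENDED output. For a given join
+    // query string, an equivalent manual check against the same SparkSession would be, roughly:
+    //
+    //   String plan = spark.sql("EXPLAIN EXTENDED " + query).collectAsList().get(0).getString(0);
+    //   assertThat(StringUtils.countMatches(plan, "dynamicpruningexpression")).isPositive();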
df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE, data STRING) USING parquet"); + Dataset dimDF = + spark + .range(1, 2) + .withColumn("date", expr("DATE '1970-01-02'")) + .withColumn("data", expr("'1970-01-02'")) + .select("id", "date", "data"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.id = d.id AND f.data = d.data AND d.date = DATE '1970-01-02'", + tableName); + + assertQueryContainsRuntimeFilters(query, 2, "Query should have 2 runtime filters"); + + deleteNotMatchingFiles(Expressions.equal("id", 1), 31); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 AND data = '1970-01-02'", tableName), + sql(query)); + } + + @TestTemplate + public void testCaseSensitivityOfRuntimeFilters() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (data, bucket(8, id))", + tableName); + configurePlanningMode(planningMode); + + Dataset df = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE, data STRING) USING parquet"); + Dataset dimDF = + spark + .range(1, 2) + .withColumn("date", expr("DATE '1970-01-02'")) + .withColumn("data", expr("'1970-01-02'")) + .select("id", "date", "data"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String caseInsensitiveQuery = + String.format( + "select f.* from %s F join dim d ON f.Id = d.iD and f.DaTa = d.dAtA and d.dAtE = date '1970-01-02'", + tableName); + + assertQueryContainsRuntimeFilters( + caseInsensitiveQuery, 2, "Query should have 2 runtime filters"); + + deleteNotMatchingFiles(Expressions.equal("id", 1), 31); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 AND data = '1970-01-02'", tableName), + sql(caseInsensitiveQuery)); + } + + @TestTemplate + public void testBucketedTableWithMultipleSpecs() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) USING iceberg", + tableName); + configurePlanningMode(planningMode); + + Dataset df1 = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 2 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df1.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + table.updateSpec().addField(Expressions.bucket("id", 8)).commit(); + + sql("REFRESH TABLE %s", tableName); + + Dataset df2 = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df2.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 2).withColumn("date", expr("DATE 
'1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.id = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("id", 1), 7); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 ORDER BY date", tableName), + sql(query)); + } + + @TestTemplate + public void testSourceColumnWithDots() throws NoSuchTableException { + sql( + "CREATE TABLE %s (`i.d` BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (bucket(8, `i.d`))", + tableName); + configurePlanningMode(planningMode); + + Dataset df = + spark + .range(1, 100) + .withColumnRenamed("id", "i.d") + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(`i.d` % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("`i.d`", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("SELECT * FROM %s WHERE `i.d` = 1", tableName); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 2).withColumn("date", expr("DATE '1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.`i.d` = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("i.d", 1), 7); + + sql(query); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE `i.d` = 1 ORDER BY date", tableName), + sql(query)); + } + + @TestTemplate + public void testSourceColumnWithBackticks() throws NoSuchTableException { + sql( + "CREATE TABLE %s (`i``d` BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (bucket(8, `i``d`))", + tableName); + configurePlanningMode(planningMode); + + Dataset df = + spark + .range(1, 100) + .withColumnRenamed("id", "i`d") + .withColumn( + "date", date_add(expr("DATE '1970-01-01'"), expr("CAST(`i``d` % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("`i``d`", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 2).withColumn("date", expr("DATE '1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.`i``d` = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("i`d", 1), 7); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE `i``d` = 1 ORDER BY date", tableName), + sql(query)); + } + + @TestTemplate + public void testUnpartitionedTable() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) USING iceberg", + tableName); + configurePlanningMode(planningMode); + + Dataset df = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS 
INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 2).withColumn("date", expr("DATE '1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.id = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsNoRuntimeFilter(query); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 ORDER BY date", tableName), + sql(query)); + } + + private void assertQueryContainsRuntimeFilter(String query) { + assertQueryContainsRuntimeFilters(query, 1, "Query should have 1 runtime filter"); + } + + private void assertQueryContainsNoRuntimeFilter(String query) { + assertQueryContainsRuntimeFilters(query, 0, "Query should have no runtime filters"); + } + + private void assertQueryContainsRuntimeFilters( + String query, int expectedFilterCount, String errorMessage) { + List output = spark.sql("EXPLAIN EXTENDED " + query).collectAsList(); + String plan = output.get(0).getString(0); + int actualFilterCount = StringUtils.countMatches(plan, "dynamicpruningexpression"); + assertThat(actualFilterCount).as(errorMessage).isEqualTo(expectedFilterCount); + } + + // delete files that don't match the filter to ensure dynamic filtering works and only required + // files are read + private void deleteNotMatchingFiles(Expression filter, int expectedDeletedFileCount) { + Table table = validationCatalog.loadTable(tableIdent); + FileIO io = table.io(); + + Set matchingFileLocations = Sets.newHashSet(); + try (CloseableIterable files = table.newScan().filter(filter).planFiles()) { + for (FileScanTask file : files) { + String path = file.file().location(); + matchingFileLocations.add(path); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + Set deletedFileLocations = Sets.newHashSet(); + try (CloseableIterable files = table.newScan().planFiles()) { + for (FileScanTask file : files) { + String path = file.file().location(); + if (!matchingFileLocations.contains(path)) { + io.deleteFile(path); + deletedFileLocations.add(path); + } + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + assertThat(deletedFileLocations) + .as("Deleted unexpected number of files") + .hasSize(expectedDeletedFileCount); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSnapshotSelection.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSnapshotSelection.java new file mode 100644 index 000000000000..a7334a580ca6 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSnapshotSelection.java @@ -0,0 +1,586 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSnapshotSelection { + + @Parameters(name = "properties = {0}") + public static Object[] parameters() { + return new Object[][] { + { + ImmutableMap.of( + TableProperties.DATA_PLANNING_MODE, LOCAL.modeName(), + TableProperties.DELETE_PLANNING_MODE, LOCAL.modeName()) + }, + { + ImmutableMap.of( + TableProperties.DATA_PLANNING_MODE, DISTRIBUTED.modeName(), + TableProperties.DELETE_PLANNING_MODE, DISTRIBUTED.modeName()) + } + }; + } + + private static final Configuration CONF = new Configuration(); + private static final Schema SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + + @TempDir private Path temp; + + private static SparkSession spark = null; + + @Parameter(index = 0) + private Map properties; + + @BeforeAll + public static void startSpark() { + TestSnapshotSelection.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestSnapshotSelection.spark; + TestSnapshotSelection.spark = null; + currentSpark.stop(); + } + + @TestTemplate + public void testSnapshotSelectionById() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, properties, tableLocation); + + // produce the first snapshot + List 
firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + // produce the second snapshot + List secondBatchRecords = + Lists.newArrayList( + new SimpleRecord(4, "d"), new SimpleRecord(5, "e"), new SimpleRecord(6, "f")); + Dataset secondDf = spark.createDataFrame(secondBatchRecords, SimpleRecord.class); + secondDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + assertThat(table.snapshots()).as("Expected 2 snapshots").hasSize(2); + + // verify records in the current snapshot + Dataset currentSnapshotResult = spark.read().format("iceberg").load(tableLocation); + List currentSnapshotRecords = + currentSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + expectedRecords.addAll(secondBatchRecords); + assertThat(currentSnapshotRecords) + .as("Current snapshot rows should match") + .isEqualTo(expectedRecords); + + // verify records in the previous snapshot + Snapshot currentSnapshot = table.currentSnapshot(); + Long parentSnapshotId = currentSnapshot.parentId(); + Dataset previousSnapshotResult = + spark.read().format("iceberg").option("snapshot-id", parentSnapshotId).load(tableLocation); + List previousSnapshotRecords = + previousSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(previousSnapshotRecords) + .as("Previous snapshot rows should match") + .isEqualTo(firstBatchRecords); + } + + @TestTemplate + public void testSnapshotSelectionByTimestamp() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, properties, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + // remember the time when the first snapshot was valid + long firstSnapshotTimestamp = System.currentTimeMillis(); + + // produce the second snapshot + List secondBatchRecords = + Lists.newArrayList( + new SimpleRecord(4, "d"), new SimpleRecord(5, "e"), new SimpleRecord(6, "f")); + Dataset secondDf = spark.createDataFrame(secondBatchRecords, SimpleRecord.class); + secondDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + assertThat(table.snapshots()).as("Expected 2 snapshots").hasSize(2); + + // verify records in the current snapshot + Dataset currentSnapshotResult = spark.read().format("iceberg").load(tableLocation); + List currentSnapshotRecords = + currentSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + expectedRecords.addAll(secondBatchRecords); + assertThat(currentSnapshotRecords) + .as("Current snapshot rows should match") + .isEqualTo(expectedRecords); + + // verify records in the previous snapshot + Dataset previousSnapshotResult 
= + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, firstSnapshotTimestamp) + .load(tableLocation); + List previousSnapshotRecords = + previousSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(previousSnapshotRecords) + .as("Previous snapshot rows should match") + .isEqualTo(firstBatchRecords); + } + + @TestTemplate + public void testSnapshotSelectionByInvalidSnapshotId() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + tables.create(SCHEMA, spec, properties, tableLocation); + + Dataset df = spark.read().format("iceberg").option("snapshot-id", -10).load(tableLocation); + + assertThatThrownBy(df::collectAsList) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot find snapshot with ID -10"); + } + + @TestTemplate + public void testSnapshotSelectionByInvalidTimestamp() throws IOException { + long timestamp = System.currentTimeMillis(); + + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + tables.create(SCHEMA, spec, properties, tableLocation); + + assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .load(tableLocation)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Cannot find a snapshot older than"); + } + + @TestTemplate + public void testSnapshotSelectionBySnapshotIdAndTimestamp() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, properties, tableLocation); + + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + long timestamp = System.currentTimeMillis(); + long snapshotId = table.currentSnapshot().snapshotId(); + + assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotId) + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .load(tableLocation)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Can specify only one of snapshot-id") + .hasMessageContaining("as-of-timestamp") + .hasMessageContaining("branch") + .hasMessageContaining("tag"); + } + + @TestTemplate + public void testSnapshotSelectionByTag() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, properties, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + table.manageSnapshots().createTag("tag", 
table.currentSnapshot().snapshotId()).commit(); + + // produce the second snapshot + List secondBatchRecords = + Lists.newArrayList( + new SimpleRecord(4, "d"), new SimpleRecord(5, "e"), new SimpleRecord(6, "f")); + Dataset secondDf = spark.createDataFrame(secondBatchRecords, SimpleRecord.class); + secondDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + // verify records in the current snapshot by tag + Dataset currentSnapshotResult = + spark.read().format("iceberg").option("tag", "tag").load(tableLocation); + List currentSnapshotRecords = + currentSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + assertThat(currentSnapshotRecords) + .as("Current snapshot rows should match") + .isEqualTo(expectedRecords); + } + + @TestTemplate + public void testSnapshotSelectionByBranch() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, properties, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + table.manageSnapshots().createBranch("branch", table.currentSnapshot().snapshotId()).commit(); + + // produce the second snapshot + List secondBatchRecords = + Lists.newArrayList( + new SimpleRecord(4, "d"), new SimpleRecord(5, "e"), new SimpleRecord(6, "f")); + Dataset secondDf = spark.createDataFrame(secondBatchRecords, SimpleRecord.class); + secondDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + // verify records in the current snapshot by branch + Dataset currentSnapshotResult = + spark.read().format("iceberg").option("branch", "branch").load(tableLocation); + List currentSnapshotRecords = + currentSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + assertThat(currentSnapshotRecords) + .as("Current snapshot rows should match") + .isEqualTo(expectedRecords); + } + + @TestTemplate + public void testSnapshotSelectionByBranchAndTagFails() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, properties, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + table.manageSnapshots().createBranch("branch", table.currentSnapshot().snapshotId()).commit(); + table.manageSnapshots().createTag("tag", table.currentSnapshot().snapshotId()).commit(); + + assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .option(SparkReadOptions.TAG, "tag") + .option(SparkReadOptions.BRANCH, "branch") + 
.load(tableLocation) + .show()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Can specify only one of snapshot-id"); + } + + @TestTemplate + public void testSnapshotSelectionByTimestampAndBranchOrTagFails() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, properties, tableLocation); + + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + long timestamp = System.currentTimeMillis(); + table.manageSnapshots().createBranch("branch", table.currentSnapshot().snapshotId()).commit(); + table.manageSnapshots().createTag("tag", table.currentSnapshot().snapshotId()).commit(); + + assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .option(SparkReadOptions.BRANCH, "branch") + .load(tableLocation) + .show()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Can specify only one of snapshot-id"); + + assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .option(SparkReadOptions.TAG, "tag") + .load(tableLocation) + .show()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Can specify only one of snapshot-id"); + } + + @TestTemplate + public void testSnapshotSelectionByBranchWithSchemaChange() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, properties, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + table.manageSnapshots().createBranch("branch", table.currentSnapshot().snapshotId()).commit(); + + Dataset branchSnapshotResult = + spark.read().format("iceberg").option("branch", "branch").load(tableLocation); + List branchSnapshotRecords = + branchSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + assertThat(branchSnapshotRecords) + .as("Current snapshot rows should match") + .isEqualTo(expectedRecords); + + // Deleting a column to indicate schema change + table.updateSchema().deleteColumn("data").commit(); + + // The data should not have the deleted column + assertThat( + spark + .read() + .format("iceberg") + .option("branch", "branch") + .load(tableLocation) + .orderBy("id") + .collectAsList()) + .containsExactly(RowFactory.create(1), RowFactory.create(2), RowFactory.create(3)); + + // re-introducing the column should not let the data re-appear + table.updateSchema().addColumn("data", Types.StringType.get()).commit(); + + assertThat( + spark + .read() + .format("iceberg") + .option("branch", "branch") + 
.load(tableLocation) + .orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList()) + .containsExactly( + new SimpleRecord(1, null), new SimpleRecord(2, null), new SimpleRecord(3, null)); + } + + @TestTemplate + public void testWritingToBranchAfterSchemaChange() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, properties, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + table.manageSnapshots().createBranch("branch", table.currentSnapshot().snapshotId()).commit(); + + Dataset branchSnapshotResult = + spark.read().format("iceberg").option("branch", "branch").load(tableLocation); + List branchSnapshotRecords = + branchSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + assertThat(branchSnapshotRecords) + .as("Current snapshot rows should match") + .isEqualTo(expectedRecords); + + // Deleting and add a new column of the same type to indicate schema change + table.updateSchema().deleteColumn("data").addColumn("zip", Types.IntegerType.get()).commit(); + + assertThat( + spark + .read() + .format("iceberg") + .option("branch", "branch") + .load(tableLocation) + .orderBy("id") + .collectAsList()) + .containsExactly( + RowFactory.create(1, null), RowFactory.create(2, null), RowFactory.create(3, null)); + + // writing new records into the branch should work with the new column + List records = + Lists.newArrayList( + RowFactory.create(4, 12345), RowFactory.create(5, 54321), RowFactory.create(6, 67890)); + + Dataset dataFrame = + spark.createDataFrame( + records, + SparkSchemaUtil.convert( + new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "zip", Types.IntegerType.get())))); + dataFrame + .select("id", "zip") + .write() + .format("iceberg") + .option("branch", "branch") + .mode("append") + .save(tableLocation); + + assertThat( + spark + .read() + .format("iceberg") + .option("branch", "branch") + .load(tableLocation) + .collectAsList()) + .hasSize(6) + .contains( + RowFactory.create(1, null), RowFactory.create(2, null), RowFactory.create(3, null)) + .containsAll(records); + } + + @TestTemplate + public void testSnapshotSelectionByTagWithSchemaChange() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, properties, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + table.manageSnapshots().createTag("tag", table.currentSnapshot().snapshotId()).commit(); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + + 
Dataset tagSnapshotResult = + spark.read().format("iceberg").option("tag", "tag").load(tableLocation); + List tagSnapshotRecords = + tagSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(tagSnapshotRecords) + .as("Current snapshot rows should match") + .isEqualTo(expectedRecords); + + // Deleting a column to indicate schema change + table.updateSchema().deleteColumn("data").commit(); + + // The data should have the deleted column as it was captured in an earlier snapshot. + Dataset deletedColumnTagSnapshotResult = + spark.read().format("iceberg").option("tag", "tag").load(tableLocation); + List deletedColumnTagSnapshotRecords = + deletedColumnTagSnapshotResult + .orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + assertThat(deletedColumnTagSnapshotRecords) + .as("Current snapshot rows should match") + .isEqualTo(expectedRecords); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAggregates.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAggregates.java new file mode 100644 index 000000000000..06b68b77e680 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAggregates.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Map; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkAggregates; +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.aggregate.Count; +import org.apache.spark.sql.connector.expressions.aggregate.CountStar; +import org.apache.spark.sql.connector.expressions.aggregate.Max; +import org.apache.spark.sql.connector.expressions.aggregate.Min; +import org.junit.jupiter.api.Test; + +public class TestSparkAggregates { + + @Test + public void testAggregates() { + Map attrMap = Maps.newHashMap(); + attrMap.put("id", "id"); + attrMap.put("`i.d`", "i.d"); + attrMap.put("`i``d`", "i`d"); + attrMap.put("`d`.b.`dd```", "d.b.dd`"); + attrMap.put("a.`aa```.c", "a.aa`.c"); + + attrMap.forEach( + (quoted, unquoted) -> { + NamedReference namedReference = FieldReference.apply(quoted); + + Max max = new Max(namedReference); + Expression expectedMax = Expressions.max(unquoted); + Expression actualMax = SparkAggregates.convert(max); + assertThat(String.valueOf(actualMax)) + .as("Max must match") + .isEqualTo(expectedMax.toString()); + + Min min = new Min(namedReference); + Expression expectedMin = Expressions.min(unquoted); + Expression actualMin = SparkAggregates.convert(min); + assertThat(String.valueOf(actualMin)) + .as("Min must match") + .isEqualTo(expectedMin.toString()); + + Count count = new Count(namedReference, false); + Expression expectedCount = Expressions.count(unquoted); + Expression actualCount = SparkAggregates.convert(count); + assertThat(String.valueOf(actualCount)) + .as("Count must match") + .isEqualTo(expectedCount.toString()); + + Count countDistinct = new Count(namedReference, true); + Expression convertedCountDistinct = SparkAggregates.convert(countDistinct); + assertThat(convertedCountDistinct).as("Count Distinct is converted to null").isNull(); + + CountStar countStar = new CountStar(); + Expression expectedCountStar = Expressions.countStar(); + Expression actualCountStar = SparkAggregates.convert(countStar); + assertThat(String.valueOf(actualCountStar)) + .as("CountStar must match") + .isEqualTo(expectedCountStar.toString()); + }); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAppenderFactory.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAppenderFactory.java new file mode 100644 index 000000000000..0664400c7911 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAppenderFactory.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.TestAppenderFactory; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkAppenderFactory extends TestAppenderFactory<InternalRow> { + + private final StructType sparkType = SparkSchemaUtil.convert(SCHEMA); + + @Override + protected FileAppenderFactory<InternalRow> createAppenderFactory( + List<Integer> equalityFieldIds, Schema eqDeleteSchema, Schema posDeleteRowSchema) { + return SparkAppenderFactory.builderFor(table, table.schema(), sparkType) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .eqDeleteRowSchema(eqDeleteSchema) + .posDelRowSchema(posDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow createRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } + + @Override + protected StructLikeSet expectedRowSet(Iterable<InternalRow> rows) { + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + for (InternalRow row : rows) { + InternalRowWrapper wrapper = new InternalRowWrapper(sparkType, table.schema().asStruct()); + set.add(wrapper.wrap(row)); + } + return set; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalog.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalog.java new file mode 100644 index 000000000000..1f266380cdc1 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalog.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.FunctionCatalog; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.SupportsNamespaces; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.ViewCatalog; + +public class TestSparkCatalog< + T extends TableCatalog & FunctionCatalog & SupportsNamespaces & ViewCatalog> + extends SparkSessionCatalog<T> { + + private static final Map<Identifier, Table> TABLE_MAP = Maps.newHashMap(); + + public static void setTable(Identifier ident, Table table) { + Preconditions.checkArgument( + !TABLE_MAP.containsKey(ident), "Cannot set " + ident + ". It is already set"); + TABLE_MAP.put(ident, table); + } + + @Override + public Table loadTable(Identifier ident) throws NoSuchTableException { + if (TABLE_MAP.containsKey(ident)) { + return TABLE_MAP.get(ident); + } + + TableIdentifier tableIdentifier = Spark3Util.identifierToTableIdentifier(ident); + Namespace namespace = tableIdentifier.namespace(); + + TestTables.TestTable table = TestTables.load(tableIdentifier.toString()); + if (table == null && namespace.equals(Namespace.of("default"))) { + table = TestTables.load(tableIdentifier.name()); + } + + return new SparkTable(table, false); + } + + public static void clearTables() { + TABLE_MAP.clear(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogCacheExpiration.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogCacheExpiration.java new file mode 100644 index 000000000000..2a9bbca40f94 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogCacheExpiration.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Map; +import org.apache.iceberg.CachingCatalog; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkCatalogCacheExpiration extends TestBaseWithCatalog { + + private static final Map SESSION_CATALOG_CONFIG = + ImmutableMap.of( + "type", + "hadoop", + "default-namespace", + "default", + CatalogProperties.CACHE_ENABLED, + "true", + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS, + "3000"); + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + {"spark_catalog", SparkSessionCatalog.class.getName(), SESSION_CATALOG_CONFIG}, + }; + } + + private static String asSqlConfCatalogKeyFor(String catalog, String configKey) { + // configKey is empty when the catalog's class is being defined + if (configKey.isEmpty()) { + return String.format("spark.sql.catalog.%s", catalog); + } else { + return String.format("spark.sql.catalog.%s.%s", catalog, configKey); + } + } + + // Add more catalogs to the spark session, so we only need to start spark one time for multiple + // different catalog configuration tests. + @BeforeAll + public static void beforeClass() { + // Catalog - expiration_disabled: Catalog with caching on and expiration disabled. + ImmutableMap.of( + "", + "org.apache.iceberg.spark.SparkCatalog", + "type", + "hive", + CatalogProperties.CACHE_ENABLED, + "true", + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS, + "-1") + .forEach((k, v) -> spark.conf().set(asSqlConfCatalogKeyFor("expiration_disabled", k), v)); + + // Catalog - cache_disabled_implicitly: Catalog that does not cache, as the cache expiration + // interval is 0. 
+ ImmutableMap.of( + "", + "org.apache.iceberg.spark.SparkCatalog", + "type", + "hive", + CatalogProperties.CACHE_ENABLED, + "true", + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS, + "0") + .forEach( + (k, v) -> spark.conf().set(asSqlConfCatalogKeyFor("cache_disabled_implicitly", k), v)); + } + + @TestTemplate + public void testSparkSessionCatalogWithExpirationEnabled() { + SparkSessionCatalog sparkCatalog = sparkSessionCatalog(); + assertThat(sparkCatalog) + .extracting("icebergCatalog") + .extracting("cacheEnabled") + .isEqualTo(true); + + assertThat(sparkCatalog) + .extracting("icebergCatalog") + .extracting("icebergCatalog") + .isInstanceOfSatisfying( + Catalog.class, + icebergCatalog -> { + assertThat(icebergCatalog) + .isExactlyInstanceOf(CachingCatalog.class) + .extracting("expirationIntervalMillis") + .isEqualTo(3000L); + }); + } + + @TestTemplate + public void testCacheEnabledAndExpirationDisabled() { + SparkCatalog sparkCatalog = getSparkCatalog("expiration_disabled"); + assertThat(sparkCatalog).extracting("cacheEnabled").isEqualTo(true); + + assertThat(sparkCatalog) + .extracting("icebergCatalog") + .isInstanceOfSatisfying( + CachingCatalog.class, + icebergCatalog -> { + assertThat(icebergCatalog).extracting("expirationIntervalMillis").isEqualTo(-1L); + }); + } + + @TestTemplate + public void testCacheDisabledImplicitly() { + SparkCatalog sparkCatalog = getSparkCatalog("cache_disabled_implicitly"); + assertThat(sparkCatalog).extracting("cacheEnabled").isEqualTo(false); + + assertThat(sparkCatalog) + .extracting("icebergCatalog") + .isInstanceOfSatisfying( + Catalog.class, + icebergCatalog -> assertThat(icebergCatalog).isNotInstanceOf(CachingCatalog.class)); + } + + private SparkSessionCatalog sparkSessionCatalog() { + TableCatalog catalog = + (TableCatalog) spark.sessionState().catalogManager().catalog("spark_catalog"); + return (SparkSessionCatalog) catalog; + } + + private SparkCatalog getSparkCatalog(String catalog) { + return (SparkCatalog) spark.sessionState().catalogManager().catalog(catalog); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogHadoopOverrides.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogHadoopOverrides.java new file mode 100644 index 000000000000..c031f2991fed --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogHadoopOverrides.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.hadoop.conf.Configurable; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.KryoHelpers; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Table; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkCatalogHadoopOverrides extends CatalogTestBase { + + private static final String CONFIG_TO_OVERRIDE = "fs.s3a.buffer.dir"; + // prepend "hadoop." so that the test base formats SQLConf correctly + // as `spark.sql.catalogs..hadoop. + private static final String HADOOP_PREFIXED_CONFIG_TO_OVERRIDE = "hadoop." + CONFIG_TO_OVERRIDE; + private static final String CONFIG_OVERRIDE_VALUE = "/tmp-overridden"; + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", + "hive", + "default-namespace", + "default", + HADOOP_PREFIXED_CONFIG_TO_OVERRIDE, + CONFIG_OVERRIDE_VALUE) + }, + { + "testhadoop", + SparkCatalog.class.getName(), + ImmutableMap.of("type", "hadoop", HADOOP_PREFIXED_CONFIG_TO_OVERRIDE, CONFIG_OVERRIDE_VALUE) + }, + { + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", + "hive", + "default-namespace", + "default", + HADOOP_PREFIXED_CONFIG_TO_OVERRIDE, + CONFIG_OVERRIDE_VALUE) + } + }; + } + + @BeforeEach + public void createTable() { + sql("CREATE TABLE IF NOT EXISTS %s (id bigint) USING iceberg", tableName(tableIdent.name())); + } + + @AfterEach + public void dropTable() { + sql("DROP TABLE IF EXISTS %s", tableName(tableIdent.name())); + } + + @TestTemplate + public void testTableFromCatalogHasOverrides() throws Exception { + Table table = getIcebergTableFromSparkCatalog(); + Configuration conf = ((Configurable) table.io()).getConf(); + String actualCatalogOverride = conf.get(CONFIG_TO_OVERRIDE, "/whammies"); + assertThat(actualCatalogOverride) + .as( + "Iceberg tables from spark should have the overridden hadoop configurations from the spark config") + .isEqualTo(CONFIG_OVERRIDE_VALUE); + } + + @TestTemplate + public void ensureRoundTripSerializedTableRetainsHadoopConfig() throws Exception { + Table table = getIcebergTableFromSparkCatalog(); + Configuration originalConf = ((Configurable) table.io()).getConf(); + String actualCatalogOverride = originalConf.get(CONFIG_TO_OVERRIDE, "/whammies"); + assertThat(actualCatalogOverride) + .as( + "Iceberg tables from spark should have the overridden hadoop configurations from the spark config") + .isEqualTo(CONFIG_OVERRIDE_VALUE); + + // Now convert to SerializableTable and ensure overridden property is still present. 
+ Table serializableTable = SerializableTableWithSize.copyOf(table); + Table kryoSerializedTable = + KryoHelpers.roundTripSerialize(SerializableTableWithSize.copyOf(table)); + Configuration configFromKryoSerde = ((Configurable) kryoSerializedTable.io()).getConf(); + String kryoSerializedCatalogOverride = configFromKryoSerde.get(CONFIG_TO_OVERRIDE, "/whammies"); + assertThat(kryoSerializedCatalogOverride) + .as( + "Tables serialized with Kryo serialization should retain overridden hadoop configuration properties") + .isEqualTo(CONFIG_OVERRIDE_VALUE); + + // Do the same for Java based serde + Table javaSerializedTable = TestHelpers.roundTripSerialize(serializableTable); + Configuration configFromJavaSerde = ((Configurable) javaSerializedTable.io()).getConf(); + String javaSerializedCatalogOverride = configFromJavaSerde.get(CONFIG_TO_OVERRIDE, "/whammies"); + assertThat(javaSerializedCatalogOverride) + .as( + "Tables serialized with Java serialization should retain overridden hadoop configuration properties") + .isEqualTo(CONFIG_OVERRIDE_VALUE); + } + + @SuppressWarnings("ThrowSpecificity") + private Table getIcebergTableFromSparkCatalog() throws Exception { + Identifier identifier = Identifier.of(tableIdent.namespace().levels(), tableIdent.name()); + TableCatalog catalog = + (TableCatalog) spark.sessionState().catalogManager().catalog(catalogName); + SparkTable sparkTable = (SparkTable) catalog.loadTable(identifier); + return sparkTable.table(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDVWriters.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDVWriters.java new file mode 100644 index 000000000000..dfc693d3094d --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDVWriters.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestDVWriters; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkDVWriters extends TestDVWriters<InternalRow> { + + @Override + protected FileWriterFactory<InternalRow> newWriterFactory( + Schema dataSchema, + List<Integer> equalityFieldIds, + Schema equalityDeleteRowSchema, + Schema positionDeleteRowSchema) { + return SparkFileWriterFactory.builderFor(table) + .dataSchema(table.schema()) + .dataFileFormat(dataFormat()) + .deleteFileFormat(dataFormat()) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .equalityDeleteRowSchema(equalityDeleteRowSchema) + .positionDeleteRowSchema(positionDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } + + @Override + protected StructLikeSet toSet(Iterable<InternalRow> rows) { + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + StructType sparkType = SparkSchemaUtil.convert(table.schema()); + for (InternalRow row : rows) { + InternalRowWrapper wrapper = new InternalRowWrapper(sparkType, table.schema().asStruct()); + set.add(wrapper.wrap(row)); + } + return set; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataFile.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataFile.java new file mode 100644 index 000000000000..182b1ef8f5af --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataFile.java @@ -0,0 +1,338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileMetadata; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestReader; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Metrics; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.UpdatePartitionSpec; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkDataFile; +import org.apache.iceberg.spark.SparkDeleteFile; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.types.Conversions; +import org.apache.iceberg.types.Types; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.ColumnName; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestSparkDataFile { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = + new Schema( + required(100, "id", Types.LongType.get()), + optional(101, "data", Types.StringType.get()), + required(102, "b", Types.BooleanType.get()), + optional(103, "i", Types.IntegerType.get()), + required(104, "l", Types.LongType.get()), + optional(105, "f", Types.FloatType.get()), + required(106, "d", Types.DoubleType.get()), + optional(107, "date", Types.DateType.get()), + required(108, "ts", Types.TimestampType.withZone()), + required(109, "tsntz", Types.TimestampType.withoutZone()), + required(110, "s", Types.StringType.get()), + optional(113, "bytes", Types.BinaryType.get()), + required(114, "dec_9_0", Types.DecimalType.of(9, 0)), + required(115, "dec_11_2", Types.DecimalType.of(11, 2)), + required(116, "dec_38_10", Types.DecimalType.of(38, 10)) // maximum precision + ); + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA) + .identity("b") + .bucket("i", 2) + .identity("l") + .identity("f") + .identity("d") 
+ .identity("date") + .hour("ts") + .identity("ts") + .identity("tsntz") + .truncate("s", 2) + .identity("bytes") + .bucket("dec_9_0", 2) + .bucket("dec_11_2", 2) + .bucket("dec_38_10", 2) + .build(); + + private static SparkSession spark; + private static JavaSparkContext sparkContext = null; + + @BeforeAll + public static void startSpark() { + TestSparkDataFile.spark = SparkSession.builder().master("local[2]").getOrCreate(); + TestSparkDataFile.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestSparkDataFile.spark; + TestSparkDataFile.spark = null; + TestSparkDataFile.sparkContext = null; + currentSpark.stop(); + } + + @TempDir private File tableDir; + private String tableLocation = null; + + @BeforeEach + public void setupTableLocation() throws Exception { + this.tableLocation = tableDir.toURI().toString(); + } + + @Test + public void testValueConversion() throws IOException { + Table table = + TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + checkSparkContentFiles(table); + } + + @Test + public void testValueConversionPartitionedTable() throws IOException { + Table table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + checkSparkContentFiles(table); + } + + @Test + public void testValueConversionWithEmptyStats() throws IOException { + Map props = Maps.newHashMap(); + props.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "none"); + Table table = TABLES.create(SCHEMA, SPEC, props, tableLocation); + checkSparkContentFiles(table); + } + + private void checkSparkContentFiles(Table table) throws IOException { + Iterable rows = RandomData.generateSpark(table.schema(), 200, 0); + JavaRDD rdd = sparkContext.parallelize(Lists.newArrayList(rows)); + Dataset df = + spark.internalCreateDataFrame( + JavaRDD.toRDD(rdd), SparkSchemaUtil.convert(table.schema()), false); + + df.write().format("iceberg").mode("append").save(tableLocation); + + table.refresh(); + + PartitionSpec dataFilesSpec = table.spec(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + assertThat(manifests).hasSize(1); + + List dataFiles = Lists.newArrayList(); + try (ManifestReader reader = ManifestFiles.read(manifests.get(0), table.io())) { + for (DataFile dataFile : reader) { + checkDataFile(dataFile.copy(), DataFiles.builder(dataFilesSpec).copy(dataFile).build()); + dataFiles.add(dataFile.copy()); + } + } + + UpdatePartitionSpec updateSpec = table.updateSpec(); + for (PartitionField field : dataFilesSpec.fields()) { + updateSpec.removeField(field.name()); + } + updateSpec.commit(); + + List positionDeleteFiles = Lists.newArrayList(); + List equalityDeleteFiles = Lists.newArrayList(); + + RowDelta rowDelta = table.newRowDelta(); + + for (DataFile dataFile : dataFiles) { + DeleteFile positionDeleteFile = createPositionDeleteFile(table, dataFile); + positionDeleteFiles.add(positionDeleteFile); + rowDelta.addDeletes(positionDeleteFile); + } + + DeleteFile equalityDeleteFile1 = createEqualityDeleteFile(table); + equalityDeleteFiles.add(equalityDeleteFile1); + rowDelta.addDeletes(equalityDeleteFile1); + + DeleteFile equalityDeleteFile2 = createEqualityDeleteFile(table); + equalityDeleteFiles.add(equalityDeleteFile2); + rowDelta.addDeletes(equalityDeleteFile2); + + rowDelta.commit(); + + Dataset dataFileDF = spark.read().format("iceberg").load(tableLocation + "#data_files"); + List sparkDataFiles = shuffleColumns(dataFileDF).collectAsList(); + 
assertThat(sparkDataFiles).hasSameSizeAs(dataFiles); + + Types.StructType dataFileType = DataFile.getType(dataFilesSpec.partitionType()); + StructType sparkDataFileType = sparkDataFiles.get(0).schema(); + SparkDataFile dataFileWrapper = new SparkDataFile(dataFileType, sparkDataFileType); + + for (int i = 0; i < dataFiles.size(); i++) { + checkDataFile(dataFiles.get(i), dataFileWrapper.wrap(sparkDataFiles.get(i))); + } + + Dataset positionDeleteFileDF = + spark.read().format("iceberg").load(tableLocation + "#delete_files").where("content = 1"); + List sparkPositionDeleteFiles = shuffleColumns(positionDeleteFileDF).collectAsList(); + assertThat(sparkPositionDeleteFiles).hasSameSizeAs(positionDeleteFiles); + + Types.StructType positionDeleteFileType = DataFile.getType(dataFilesSpec.partitionType()); + StructType sparkPositionDeleteFileType = sparkPositionDeleteFiles.get(0).schema(); + SparkDeleteFile positionDeleteFileWrapper = + new SparkDeleteFile(positionDeleteFileType, sparkPositionDeleteFileType); + + for (int i = 0; i < positionDeleteFiles.size(); i++) { + checkDeleteFile( + positionDeleteFiles.get(i), + positionDeleteFileWrapper.wrap(sparkPositionDeleteFiles.get(i))); + } + + Dataset equalityDeleteFileDF = + spark.read().format("iceberg").load(tableLocation + "#delete_files").where("content = 2"); + List sparkEqualityDeleteFiles = shuffleColumns(equalityDeleteFileDF).collectAsList(); + assertThat(sparkEqualityDeleteFiles).hasSameSizeAs(equalityDeleteFiles); + + Types.StructType equalityDeleteFileType = DataFile.getType(table.spec().partitionType()); + StructType sparkEqualityDeleteFileType = sparkEqualityDeleteFiles.get(0).schema(); + SparkDeleteFile equalityDeleteFileWrapper = + new SparkDeleteFile(equalityDeleteFileType, sparkEqualityDeleteFileType); + + for (int i = 0; i < equalityDeleteFiles.size(); i++) { + checkDeleteFile( + equalityDeleteFiles.get(i), + equalityDeleteFileWrapper.wrap(sparkEqualityDeleteFiles.get(i))); + } + } + + private Dataset shuffleColumns(Dataset df) { + List columns = + Arrays.stream(df.columns()).map(ColumnName::new).collect(Collectors.toList()); + Collections.shuffle(columns); + return df.select(columns.toArray(new Column[0])); + } + + private void checkDataFile(DataFile expected, DataFile actual) { + assertThat(expected.equalityFieldIds()).isNull(); + assertThat(actual.equalityFieldIds()).isNull(); + checkContentFile(expected, actual); + checkStructLike(expected.partition(), actual.partition()); + } + + private void checkDeleteFile(DeleteFile expected, DeleteFile actual) { + assertThat(expected.equalityFieldIds()).isEqualTo(actual.equalityFieldIds()); + checkContentFile(expected, actual); + checkStructLike(expected.partition(), actual.partition()); + } + + private void checkContentFile(ContentFile expected, ContentFile actual) { + assertThat(actual.content()).isEqualTo(expected.content()); + assertThat(actual.location()).isEqualTo(expected.location()); + assertThat(actual.format()).isEqualTo(expected.format()); + assertThat(actual.recordCount()).isEqualTo(expected.recordCount()); + assertThat(actual.fileSizeInBytes()).isEqualTo(expected.fileSizeInBytes()); + assertThat(actual.valueCounts()).isEqualTo(expected.valueCounts()); + assertThat(actual.nullValueCounts()).isEqualTo(expected.nullValueCounts()); + assertThat(actual.nanValueCounts()).isEqualTo(expected.nanValueCounts()); + assertThat(actual.lowerBounds()).isEqualTo(expected.lowerBounds()); + assertThat(actual.upperBounds()).isEqualTo(expected.upperBounds()); + 
assertThat(actual.keyMetadata()).isEqualTo(expected.keyMetadata()); + assertThat(actual.splitOffsets()).isEqualTo(expected.splitOffsets()); + assertThat(actual.sortOrderId()).isEqualTo(expected.sortOrderId()); + } + + private void checkStructLike(StructLike expected, StructLike actual) { + assertThat(actual.size()).isEqualTo(expected.size()); + for (int i = 0; i < expected.size(); i++) { + assertThat(actual.get(i, Object.class)).isEqualTo(expected.get(i, Object.class)); + } + } + + private DeleteFile createPositionDeleteFile(Table table, DataFile dataFile) { + PartitionSpec spec = table.specs().get(dataFile.specId()); + return FileMetadata.deleteFileBuilder(spec) + .ofPositionDeletes() + .withPath("/path/to/pos-deletes-" + UUID.randomUUID() + ".parquet") + .withFileSizeInBytes(dataFile.fileSizeInBytes() / 4) + .withPartition(dataFile.partition()) + .withRecordCount(2) + .withMetrics( + new Metrics( + 2L, + null, // no column sizes + null, // no value counts + null, // no null counts + null, // no NaN counts + ImmutableMap.of( + MetadataColumns.DELETE_FILE_PATH.fieldId(), + Conversions.toByteBuffer(Types.StringType.get(), dataFile.location())), + ImmutableMap.of( + MetadataColumns.DELETE_FILE_PATH.fieldId(), + Conversions.toByteBuffer(Types.StringType.get(), dataFile.location())))) + .withEncryptionKeyMetadata(ByteBuffer.allocate(4).putInt(35)) + .build(); + } + + private DeleteFile createEqualityDeleteFile(Table table) { + return FileMetadata.deleteFileBuilder(table.spec()) + .ofEqualityDeletes(3, 4) + .withPath("/path/to/eq-deletes-" + UUID.randomUUID() + ".parquet") + .withFileSizeInBytes(250) + .withRecordCount(1) + .withSortOrder(SortOrder.unsorted()) + .withEncryptionKeyMetadata(ByteBuffer.allocate(4).putInt(35)) + .build(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java new file mode 100644 index 000000000000..fb2b312bed97 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java @@ -0,0 +1,742 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSparkDataWrite { + private static final Configuration CONF = new Configuration(); + + @Parameter(index = 0) + private FileFormat format; + + @Parameter(index = 1) + private String branch; + + private static SparkSession spark = null; + private static final Schema SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + + @TempDir private Path temp; + + @Parameters(name = "format = {0}, branch = {1}") + public static Object[][] parameters() { + return new Object[][] { + new Object[] {FileFormat.PARQUET, null}, + new Object[] {FileFormat.PARQUET, "main"}, + new Object[] {FileFormat.PARQUET, "testBranch"}, + new Object[] {FileFormat.AVRO, null}, + new Object[] {FileFormat.ORC, "testBranch"} + }; + } + + @BeforeAll + public static void startSpark() { + TestSparkDataWrite.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterEach + public void clearSourceCache() { + ManualSource.clearTables(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestSparkDataWrite.spark; + TestSparkDataWrite.spark = null; + currentSpark.stop(); + } + + @TestTemplate + public void testBasicWrite() throws IOException { + File parent = temp.resolve(format.toString()).toFile(); + File location = new File(parent, "test"); + 
String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + // TODO: incoming columns must be ordered according to the table's schema + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + for (ManifestFile manifest : + SnapshotUtil.latestSnapshot(table, branch).allManifests(table.io())) { + for (DataFile file : ManifestFiles.read(manifest, table.io())) { + // TODO: avro not support split + if (!format.equals(FileFormat.AVRO)) { + assertThat(file.splitOffsets()).as("Split offsets not present").isNotNull(); + } + assertThat(file.recordCount()).as("Should have reported record count as 1").isEqualTo(1); + // TODO: append more metric info + if (format.equals(FileFormat.PARQUET)) { + assertThat(file.columnSizes()).as("Column sizes metric not present").isNotNull(); + assertThat(file.valueCounts()).as("Counts metric not present").isNotNull(); + assertThat(file.nullValueCounts()).as("Null value counts metric not present").isNotNull(); + assertThat(file.lowerBounds()).as("Lower bounds metric not present").isNotNull(); + assertThat(file.upperBounds()).as("Upper bounds metric not present").isNotNull(); + } + } + } + } + + @TestTemplate + public void testAppend() throws IOException { + File parent = temp.resolve(format.toString()).toFile(); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "a"), + new SimpleRecord(5, "b"), + new SimpleRecord(6, "c")); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + + df.withColumn("id", df.col("id").plus(3)) + .select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(targetLocation); + + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should 
match").isEqualTo(expected); + } + + @TestTemplate + public void testEmptyOverwrite() throws IOException { + File parent = temp.resolve(format.toString()).toFile(); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("id").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + List expected = records; + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + + Dataset empty = spark.createDataFrame(ImmutableList.of(), SimpleRecord.class); + empty + .select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Overwrite) + .option("overwrite-mode", "dynamic") + .save(targetLocation); + + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + } + + @TestTemplate + public void testOverwrite() throws IOException { + File parent = temp.resolve(format.toString()).toFile(); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("id").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "a"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "b"), + new SimpleRecord(6, "c")); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + + // overwrite with 2*id to replace record 2, append 4 and 6 + df.withColumn("id", df.col("id").multiply(2)) + .select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Overwrite) + .option("overwrite-mode", "dynamic") + .save(targetLocation); + + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + } + + @TestTemplate + public void testUnpartitionedOverwrite() throws IOException { + File parent = temp.resolve(format.toString()).toFile(); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = 
tables.create(SCHEMA, spec, location.toString()); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + + // overwrite with the same data; should not produce two copies + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Overwrite) + .save(targetLocation); + + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + } + + @TestTemplate + public void testUnpartitionedCreateWithTargetFileSizeViaTableProperties() throws IOException { + File parent = temp.resolve(format.toString()).toFile(); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + table + .updateProperties() + .set(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES, "4") // ~4 bytes; low enough to trigger + .commit(); + + List expected = Lists.newArrayListWithCapacity(4000); + for (int i = 0; i < 4000; i++) { + expected.add(new SimpleRecord(i, "a")); + } + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + + List files = Lists.newArrayList(); + for (ManifestFile manifest : + SnapshotUtil.latestSnapshot(table, branch).allManifests(table.io())) { + for (DataFile file : ManifestFiles.read(manifest, table.io())) { + files.add(file); + } + } + + assertThat(files).as("Should have 4 DataFiles").hasSize(4); + assertThat(files.stream()) + .as("All DataFiles contain 1000 rows") + .allMatch(d -> d.recordCount() == 1000); + } + + @TestTemplate + public void testPartitionedCreateWithTargetFileSizeViaOption() throws IOException { + partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType.NONE); + } + + @TestTemplate + public void testPartitionedFanoutCreateWithTargetFileSizeViaOption() throws IOException { + partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType.TABLE); + } + + @TestTemplate + public void testPartitionedFanoutCreateWithTargetFileSizeViaOption2() throws IOException { + partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType.JOB); + } + + @TestTemplate + public void testWriteProjection() throws IOException { + assumeThat(spark.version()) + .as("Not supported in Spark 3; analysis requires all columns are present") + .startsWith("2"); + + File 
parent = temp.resolve(format.toString()).toFile(); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, null), new SimpleRecord(2, null), new SimpleRecord(3, null)); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id") + .write() // select only id column + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + } + + @TestTemplate + public void testWriteProjectionWithMiddle() throws IOException { + assumeThat(spark.version()) + .as("Not supported in Spark 3; analysis requires all columns are present") + .startsWith("2"); + + File parent = temp.resolve(format.toString()).toFile(); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Schema schema = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + Table table = tables.create(schema, spec, location.toString()); + + List expected = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "hello"), + new ThreeColumnRecord(2, null, "world"), + new ThreeColumnRecord(3, null, null)); + + Dataset df = spark.createDataFrame(expected, ThreeColumnRecord.class); + + df.select("c1", "c3") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("c1").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + } + + @TestTemplate + public void testViewsReturnRecentResults() throws IOException { + File parent = temp.resolve(format.toString()).toFile(); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + tables.create(SCHEMA, spec, location.toString()); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + Table table = tables.load(location.toString()); + createBranch(table); + + Dataset query = spark.read().format("iceberg").load(targetLocation).where("id = 1"); + 
query.createOrReplaceTempView("tmp"); + + List actual1 = + spark.table("tmp").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expected1 = Lists.newArrayList(new SimpleRecord(1, "a")); + assertThat(actual1).as("Number of rows should match").hasSameSizeAs(expected1); + assertThat(actual1).as("Result rows should match").isEqualTo(expected1); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(targetLocation); + + List actual2 = + spark.table("tmp").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expected2 = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(1, "a")); + assertThat(actual2).as("Number of rows should match").hasSameSizeAs(expected2); + assertThat(actual2).as("Result rows should match").isEqualTo(expected2); + } + + public void partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType option) + throws IOException { + File parent = temp.resolve(format.toString()).toFile(); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Map properties = + ImmutableMap.of( + TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_NONE); + Table table = tables.create(SCHEMA, spec, properties, location.toString()); + + List expected = Lists.newArrayListWithCapacity(8000); + for (int i = 0; i < 2000; i++) { + expected.add(new SimpleRecord(i, "a")); + expected.add(new SimpleRecord(i, "b")); + expected.add(new SimpleRecord(i, "c")); + expected.add(new SimpleRecord(i, "d")); + } + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + switch (option) { + case NONE: + df.select("id", "data") + .sort("data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, 4) // ~4 bytes; low enough to trigger + .save(location.toString()); + break; + case TABLE: + table.updateProperties().set(SPARK_WRITE_PARTITIONED_FANOUT_ENABLED, "true").commit(); + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, 4) // ~4 bytes; low enough to trigger + .save(location.toString()); + break; + case JOB: + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, 4) // ~4 bytes; low enough to trigger + .option(SparkWriteOptions.FANOUT_ENABLED, true) + .save(location.toString()); + break; + default: + break; + } + + createBranch(table); + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + + List files = Lists.newArrayList(); + for (ManifestFile manifest : + SnapshotUtil.latestSnapshot(table, branch).allManifests(table.io())) { + for (DataFile file : ManifestFiles.read(manifest, table.io())) { + files.add(file); + } + } + assertThat(files).as("Should have 8 DataFiles").hasSize(8); + 
assertThat(files.stream()) + .as("All DataFiles contain 1000 rows") + .allMatch(d -> d.recordCount() == 1000); + } + + @TestTemplate + public void testCommitUnknownException() throws IOException { + File parent = temp.resolve(format.toString()).toFile(); + File location = new File(parent, "commitunknown"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + table.refresh(); + + List records2 = + Lists.newArrayList( + new SimpleRecord(4, "d"), new SimpleRecord(5, "e"), new SimpleRecord(6, "f")); + + Dataset df2 = spark.createDataFrame(records2, SimpleRecord.class); + + AppendFiles append = table.newFastAppend(); + if (branch != null) { + append.toBranch(branch); + } + + AppendFiles spyAppend = spy(append); + doAnswer( + invocation -> { + append.commit(); + throw new CommitStateUnknownException(new RuntimeException("Datacenter on Fire")); + }) + .when(spyAppend) + .commit(); + + Table spyTable = spy(table); + when(spyTable.newAppend()).thenReturn(spyAppend); + SparkTable sparkTable = new SparkTable(spyTable, false); + + String manualTableName = "unknown_exception"; + ManualSource.setTable(manualTableName, sparkTable); + + // Although an exception is thrown here, write and commit have succeeded + assertThatThrownBy( + () -> + df2.select("id", "data") + .sort("data") + .write() + .format("org.apache.iceberg.spark.source.ManualSource") + .option(ManualSource.TABLE_NAME, manualTableName) + .mode(SaveMode.Append) + .save(targetLocation)) + .isInstanceOf(CommitStateUnknownException.class) + .hasMessageStartingWith("Datacenter on Fire"); + + // Since write and commit succeeded, the rows should be readable + Dataset result = spark.read().format("iceberg").load(targetLocation); + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(actual).as("Number of rows should match").hasSize(records.size() + records2.size()); + assertThat(actual) + .describedAs("Result rows should match") + .containsExactlyInAnyOrder( + ImmutableList.builder() + .addAll(records) + .addAll(records2) + .build() + .toArray(new SimpleRecord[0])); + } + + public enum IcebergOptionsType { + NONE, + TABLE, + JOB + } + + private String locationWithBranch(File location) { + if (branch == null) { + return location.toString(); + } + + return location + "#branch_" + branch; + } + + private void createBranch(Table table) { + if (branch != null && !branch.equals(SnapshotRef.MAIN_BRANCH)) { + table.manageSnapshots().createBranch(branch, table.currentSnapshot().snapshotId()).commit(); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkFileWriterFactory.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkFileWriterFactory.java new file mode 100644 index 000000000000..575e6658db22 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkFileWriterFactory.java @@ -0,0 +1,69 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestFileWriterFactory; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkFileWriterFactory extends TestFileWriterFactory { + + @Override + protected FileWriterFactory newWriterFactory( + Schema dataSchema, + List equalityFieldIds, + Schema equalityDeleteRowSchema, + Schema positionDeleteRowSchema) { + return SparkFileWriterFactory.builderFor(table) + .dataSchema(table.schema()) + .dataFileFormat(format()) + .deleteFileFormat(format()) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .equalityDeleteRowSchema(equalityDeleteRowSchema) + .positionDeleteRowSchema(positionDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } + + @Override + protected StructLikeSet toSet(Iterable rows) { + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + StructType sparkType = SparkSchemaUtil.convert(table.schema()); + for (InternalRow row : rows) { + InternalRowWrapper wrapper = new InternalRowWrapper(sparkType, table.schema().asStruct()); + set.add(wrapper.wrap(row)); + } + return set; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMergingMetrics.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMergingMetrics.java new file mode 100644 index 000000000000..29425398f395 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMergingMetrics.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.Files; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.TestMergingMetrics; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.spark.sql.catalyst.InternalRow; + +public class TestSparkMergingMetrics extends TestMergingMetrics { + + @Override + protected FileAppender writeAndGetAppender(List records) throws IOException { + Table testTable = + new BaseTable(null, "dummy") { + @Override + public Map properties() { + return Collections.emptyMap(); + } + + @Override + public SortOrder sortOrder() { + return SortOrder.unsorted(); + } + + @Override + public PartitionSpec spec() { + return PartitionSpec.unpartitioned(); + } + }; + + File tempFile = File.createTempFile("junit", null, tempDir); + FileAppender appender = + SparkAppenderFactory.builderFor(testTable, SCHEMA, SparkSchemaUtil.convert(SCHEMA)) + .build() + .newAppender(Files.localOutput(tempFile), fileFormat); + try (FileAppender fileAppender = appender) { + records.stream() + .map(r -> new StructInternalRow(SCHEMA.asStruct()).setStruct(r)) + .forEach(fileAppender::add); + } + return appender; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java new file mode 100644 index 000000000000..230a660c0117 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; +import static org.apache.iceberg.TableProperties.ORC_VECTORIZATION_ENABLED; +import static org.apache.iceberg.TableProperties.PARQUET_BATCH_SIZE; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSparkMetadataColumns extends TestBase { + + private static final String TABLE_NAME = "test_table"; + private static final Schema SCHEMA = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "category", Types.StringType.get()), + Types.NestedField.optional(3, "data", Types.StringType.get())); + private static final PartitionSpec SPEC = PartitionSpec.unpartitioned(); + private static final PartitionSpec UNKNOWN_SPEC = + TestHelpers.newExpectedSpecBuilder() + .withSchema(SCHEMA) + .withSpecId(1) + .addField("zero", 1, "id_zero") + .build(); + + @Parameters(name = "fileFormat = {0}, vectorized = {1}, formatVersion = {2}") + public static Object[][] parameters() { + return new Object[][] { + {FileFormat.PARQUET, false, 1}, + {FileFormat.PARQUET, true, 1}, + {FileFormat.PARQUET, false, 2}, + {FileFormat.PARQUET, true, 2}, + {FileFormat.AVRO, false, 1}, + {FileFormat.AVRO, false, 2}, 
+ {FileFormat.ORC, false, 1}, + {FileFormat.ORC, true, 1}, + {FileFormat.ORC, false, 2}, + {FileFormat.ORC, true, 2}, + }; + } + + @TempDir private Path temp; + + @Parameter(index = 0) + private FileFormat fileFormat; + + @Parameter(index = 1) + private boolean vectorized; + + @Parameter(index = 2) + private int formatVersion; + + private Table table = null; + + @BeforeAll + public static void setupSpark() { + ImmutableMap<String, String> config = + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "cache-enabled", "true"); + spark + .conf() + .set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.source.TestSparkCatalog"); + config.forEach( + (key, value) -> spark.conf().set("spark.sql.catalog.spark_catalog." + key, value)); + } + + @BeforeEach + public void setupTable() throws IOException { + createAndInitTable(); + } + + @AfterEach + public void dropTable() { + TestTables.clearTables(); + } + + @TestTemplate + public void testSpecAndPartitionMetadataColumns() { + // TODO: support metadata structs in vectorized ORC reads + assumeThat(fileFormat).isNotEqualTo(FileFormat.ORC); + assumeThat(vectorized).isFalse(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec().addField("data").commit(); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec().addField(Expressions.bucket("category", 8)).commit(); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec().removeField("data").commit(); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec().renameField("category_bucket_8", "category_bucket_8_another_name").commit(); + + List<Object[]> expected = + ImmutableList.of( + row(0, row(null, null)), + row(1, row("b1", null)), + row(2, row("b1", 2)), + row(3, row(null, 2))); + assertEquals( + "Rows must match", + expected, + sql("SELECT _spec_id, _partition FROM %s ORDER BY _spec_id", TABLE_NAME)); + } + + @TestTemplate + public void testPartitionMetadataColumnWithManyColumns() { + List<Types.NestedField> fields = + Lists.newArrayList(Types.NestedField.required(0, "id", Types.LongType.get())); + List<Types.NestedField> additionalCols = + IntStream.range(1, 1010) + .mapToObj(i -> Types.NestedField.optional(i, "c" + i, Types.StringType.get())) + .collect(Collectors.toList()); + fields.addAll(additionalCols); + Schema manyColumnsSchema = new Schema(fields); + PartitionSpec spec = PartitionSpec.builderFor(manyColumnsSchema).identity("id").build(); + + TableOperations ops = ((HasTableOperations) table).operations(); + TableMetadata base = ops.current(); + ops.commit( + base, + base.updateSchema(manyColumnsSchema, manyColumnsSchema.highestFieldId()) + .updatePartitionSpec(spec)); + + Dataset<Row> df = + spark + .range(2) + .withColumns( + IntStream.range(1, 1010) + .boxed() + .collect(Collectors.toMap(i -> "c" + i, i -> expr("CAST(id as STRING)")))); + StructType sparkSchema = spark.table(TABLE_NAME).schema(); + spark + .createDataFrame(df.rdd(), sparkSchema) + .coalesce(1) + .write() + .format("iceberg") + .mode("append") + .save(TABLE_NAME); + + assertThat(spark.table(TABLE_NAME).select("*", "_partition").count()).isEqualTo(2); + List<Object[]> expected = + ImmutableList.of(row(row(0L), 0L, "0", "0", "0"), row(row(1L), 1L, "1", "1", "1")); + assertEquals( + "Rows must match", + expected, + sql("SELECT _partition, id, c999, c1000, c1001 FROM %s ORDER BY id", TABLE_NAME)); + } + + @TestTemplate + public void
testPositionMetadataColumnWithMultipleRowGroups() throws NoSuchTableException { + assumeThat(fileFormat).isEqualTo(FileFormat.PARQUET); + + table.updateProperties().set(PARQUET_ROW_GROUP_SIZE_BYTES, "100").commit(); + + List<Long> ids = Lists.newArrayList(); + for (long id = 0L; id < 200L; id++) { + ids.add(id); + } + Dataset<Row> df = + spark + .createDataset(ids, Encoders.LONG()) + .withColumnRenamed("value", "id") + .withColumn("category", lit("hr")) + .withColumn("data", lit("ABCDEF")); + df.coalesce(1).writeTo(TABLE_NAME).append(); + + assertThat(spark.table(TABLE_NAME).count()).isEqualTo(200); + + List<Object[]> expectedRows = ids.stream().map(this::row).collect(Collectors.toList()); + assertEquals("Rows must match", expectedRows, sql("SELECT _pos FROM %s", TABLE_NAME)); + } + + @TestTemplate + public void testPositionMetadataColumnWithMultipleBatches() throws NoSuchTableException { + assumeThat(fileFormat).isEqualTo(FileFormat.PARQUET); + + table.updateProperties().set(PARQUET_BATCH_SIZE, "1000").commit(); + + List<Long> ids = Lists.newArrayList(); + for (long id = 0L; id < 7500L; id++) { + ids.add(id); + } + Dataset<Row> df = + spark + .createDataset(ids, Encoders.LONG()) + .withColumnRenamed("value", "id") + .withColumn("category", lit("hr")) + .withColumn("data", lit("ABCDEF")); + df.coalesce(1).writeTo(TABLE_NAME).append(); + + assertThat(spark.table(TABLE_NAME).count()).isEqualTo(7500); + + List<Object[]> expectedRows = ids.stream().map(this::row).collect(Collectors.toList()); + assertEquals("Rows must match", expectedRows, sql("SELECT _pos FROM %s", TABLE_NAME)); + } + + @TestTemplate + public void testPartitionMetadataColumnWithUnknownTransforms() { + // replace the table spec to include an unknown transform + TableOperations ops = ((HasTableOperations) table).operations(); + TableMetadata base = ops.current(); + ops.commit(base, base.updatePartitionSpec(UNKNOWN_SPEC)); + + assertThatThrownBy(() -> sql("SELECT _partition FROM %s", TABLE_NAME)) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot build table partition type, unknown transforms: [zero]"); + } + + @TestTemplate + public void testConflictingColumns() { + table + .updateSchema() + .addColumn(MetadataColumns.SPEC_ID.name(), Types.IntegerType.get()) + .addColumn(MetadataColumns.FILE_PATH.name(), Types.StringType.get()) + .commit(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1', -1, 'path/to/file')", TABLE_NAME); + + assertEquals( + "Rows must match", + ImmutableList.of(row(1L, "a1")), + sql("SELECT id, category FROM %s", TABLE_NAME)); + + assertThatThrownBy(() -> sql("SELECT * FROM %s", TABLE_NAME)) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith( + "Table column names conflict with names reserved for Iceberg metadata columns: [_spec_id, _file]."); + + table.refresh(); + + table + .updateSchema() + .renameColumn(MetadataColumns.SPEC_ID.name(), "_renamed" + MetadataColumns.SPEC_ID.name()) + .renameColumn( + MetadataColumns.FILE_PATH.name(), "_renamed" + MetadataColumns.FILE_PATH.name()) + .commit(); + + assertEquals( + "Rows must match", + ImmutableList.of(row(0, null, -1)), + sql("SELECT _spec_id, _partition, _renamed_spec_id FROM %s", TABLE_NAME)); + } + + private void createAndInitTable() throws IOException { + Map<String, String> properties = Maps.newHashMap(); + properties.put(FORMAT_VERSION, String.valueOf(formatVersion)); + properties.put(DEFAULT_FILE_FORMAT, fileFormat.name()); + + switch (fileFormat) { + case PARQUET: + properties.put(PARQUET_VECTORIZATION_ENABLED, String.valueOf(vectorized)); + break; + case ORC: +
properties.put(ORC_VECTORIZATION_ENABLED, String.valueOf(vectorized)); + break; + default: + Preconditions.checkState( + !vectorized, "File format %s does not support vectorized reads", fileFormat); + } + + this.table = + TestTables.create( + Files.createTempDirectory(temp, "junit").toFile(), + TABLE_NAME, + SCHEMA, + SPEC, + properties); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPartitioningWriters.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPartitioningWriters.java new file mode 100644 index 000000000000..979abd21e7f7 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPartitioningWriters.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestPartitioningWriters; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkPartitioningWriters extends TestPartitioningWriters { + + @Override + protected FileWriterFactory newWriterFactory( + Schema dataSchema, + List equalityFieldIds, + Schema equalityDeleteRowSchema, + Schema positionDeleteRowSchema) { + return SparkFileWriterFactory.builderFor(table) + .dataSchema(table.schema()) + .dataFileFormat(format()) + .deleteFileFormat(format()) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .equalityDeleteRowSchema(equalityDeleteRowSchema) + .positionDeleteRowSchema(positionDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } + + @Override + protected StructLikeSet toSet(Iterable rows) { + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + StructType sparkType = SparkSchemaUtil.convert(table.schema()); + for (InternalRow row : rows) { + InternalRowWrapper wrapper = new InternalRowWrapper(sparkType, table.schema().asStruct()); + set.add(wrapper.wrap(row)); + } + return set; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPlanningUtil.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPlanningUtil.java new file mode 100644 index 000000000000..65c6790e5b49 --- /dev/null 
+++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPlanningUtil.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.when; + +import java.util.List; +import org.apache.iceberg.BaseScanTaskGroup; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataTask; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.MockFileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.TestHelpers.Row; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.iceberg.types.Types; +import org.junit.jupiter.api.TestTemplate; +import org.mockito.Mockito; + +public class TestSparkPlanningUtil extends TestBaseWithCatalog { + + private static final Schema SCHEMA = + new Schema( + required(1, "id", Types.IntegerType.get()), + required(2, "data", Types.StringType.get()), + required(3, "category", Types.StringType.get())); + private static final PartitionSpec SPEC_1 = + PartitionSpec.builderFor(SCHEMA).withSpecId(1).bucket("id", 16).identity("data").build(); + private static final PartitionSpec SPEC_2 = + PartitionSpec.builderFor(SCHEMA).withSpecId(2).identity("data").build(); + private static final List EXECUTOR_LOCATIONS = + ImmutableList.of("host1_exec1", "host1_exec2", "host1_exec3", "host2_exec1", "host2_exec2"); + + @TestTemplate + public void testFileScanTaskWithoutDeletes() { + List tasks = + ImmutableList.of( + new MockFileScanTask(mockDataFile(Row.of(1, "a")), SCHEMA, SPEC_1), + new MockFileScanTask(mockDataFile(Row.of(2, "b")), SCHEMA, SPEC_1), + new MockFileScanTask(mockDataFile(Row.of(3, "c")), SCHEMA, SPEC_1)); + ScanTaskGroup taskGroup = new BaseScanTaskGroup<>(tasks); + List> taskGroups = ImmutableList.of(taskGroup); + + String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS); + + // should not assign executors if there are no deletes + assertThat(locations.length).isEqualTo(1); + assertThat(locations[0]).isEmpty(); + } + + @TestTemplate + public void testFileScanTaskWithDeletes() { + StructLike partition1 = Row.of("k2", null); + StructLike partition2 = Row.of("k1"); + List tasks = + ImmutableList.of( + new MockFileScanTask( + mockDataFile(partition1), mockDeleteFiles(1, partition1), SCHEMA, SPEC_1), + new MockFileScanTask( + 
mockDataFile(partition2), mockDeleteFiles(3, partition2), SCHEMA, SPEC_2), + new MockFileScanTask( + mockDataFile(partition1), mockDeleteFiles(2, partition1), SCHEMA, SPEC_1)); + ScanTaskGroup taskGroup = new BaseScanTaskGroup<>(tasks); + List> taskGroups = ImmutableList.of(taskGroup); + + String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS); + + // should assign executors and handle different size of partitions + assertThat(locations.length).isEqualTo(1); + assertThat(locations[0].length).isGreaterThanOrEqualTo(1); + } + + @TestTemplate + public void testFileScanTaskWithUnpartitionedDeletes() { + List tasks1 = + ImmutableList.of( + new MockFileScanTask( + mockDataFile(Row.of()), + mockDeleteFiles(2, Row.of()), + SCHEMA, + PartitionSpec.unpartitioned()), + new MockFileScanTask( + mockDataFile(Row.of()), + mockDeleteFiles(2, Row.of()), + SCHEMA, + PartitionSpec.unpartitioned()), + new MockFileScanTask( + mockDataFile(Row.of()), + mockDeleteFiles(2, Row.of()), + SCHEMA, + PartitionSpec.unpartitioned())); + ScanTaskGroup taskGroup1 = new BaseScanTaskGroup<>(tasks1); + List tasks2 = + ImmutableList.of( + new MockFileScanTask( + mockDataFile(null), + mockDeleteFiles(2, null), + SCHEMA, + PartitionSpec.unpartitioned()), + new MockFileScanTask( + mockDataFile(null), + mockDeleteFiles(2, null), + SCHEMA, + PartitionSpec.unpartitioned()), + new MockFileScanTask( + mockDataFile(null), + mockDeleteFiles(2, null), + SCHEMA, + PartitionSpec.unpartitioned())); + ScanTaskGroup taskGroup2 = new BaseScanTaskGroup<>(tasks2); + List> taskGroups = ImmutableList.of(taskGroup1, taskGroup2); + + String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS); + + // should not assign executors if the table is unpartitioned + assertThat(locations.length).isEqualTo(2); + assertThat(locations[0]).isEmpty(); + assertThat(locations[1]).isEmpty(); + } + + @TestTemplate + public void testDataTasks() { + List tasks = + ImmutableList.of( + new MockDataTask(mockDataFile(Row.of(1, "a"))), + new MockDataTask(mockDataFile(Row.of(2, "b"))), + new MockDataTask(mockDataFile(Row.of(3, "c")))); + ScanTaskGroup taskGroup = new BaseScanTaskGroup<>(tasks); + List> taskGroups = ImmutableList.of(taskGroup); + + String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS); + + // should not assign executors for data tasks + assertThat(locations.length).isEqualTo(1); + assertThat(locations[0]).isEmpty(); + } + + @TestTemplate + public void testUnknownTasks() { + List tasks = ImmutableList.of(new UnknownScanTask(), new UnknownScanTask()); + ScanTaskGroup taskGroup = new BaseScanTaskGroup<>(tasks); + List> taskGroups = ImmutableList.of(taskGroup); + + String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS); + + // should not assign executors for unknown tasks + assertThat(locations.length).isEqualTo(1); + assertThat(locations[0]).isEmpty(); + } + + private static DataFile mockDataFile(StructLike partition) { + DataFile file = Mockito.mock(DataFile.class); + when(file.partition()).thenReturn(partition); + return file; + } + + private static DeleteFile[] mockDeleteFiles(int count, StructLike partition) { + DeleteFile[] files = new DeleteFile[count]; + for (int index = 0; index < count; index++) { + files[index] = mockDeleteFile(partition); + } + return files; + } + + private static DeleteFile mockDeleteFile(StructLike partition) { + DeleteFile file = Mockito.mock(DeleteFile.class); + 
when(file.partition()).thenReturn(partition); + return file; + } + + private static class MockDataTask extends MockFileScanTask implements DataTask { + + MockDataTask(DataFile file) { + super(file); + } + + @Override + public PartitionSpec spec() { + return PartitionSpec.unpartitioned(); + } + + @Override + public CloseableIterable rows() { + throw new UnsupportedOperationException(); + } + } + + private static class UnknownScanTask implements ScanTask {} +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPositionDeltaWriters.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPositionDeltaWriters.java new file mode 100644 index 000000000000..9dc56abf9fb6 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPositionDeltaWriters.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestPositionDeltaWriters; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkPositionDeltaWriters extends TestPositionDeltaWriters { + + @Override + protected FileWriterFactory newWriterFactory( + Schema dataSchema, + List equalityFieldIds, + Schema equalityDeleteRowSchema, + Schema positionDeleteRowSchema) { + return SparkFileWriterFactory.builderFor(table) + .dataSchema(table.schema()) + .dataFileFormat(format()) + .deleteFileFormat(format()) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .equalityDeleteRowSchema(equalityDeleteRowSchema) + .positionDeleteRowSchema(positionDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } + + @Override + protected StructLikeSet toSet(Iterable rows) { + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + StructType sparkType = SparkSchemaUtil.convert(table.schema()); + for (InternalRow row : rows) { + InternalRowWrapper wrapper = new InternalRowWrapper(sparkType, table.schema().asStruct()); + set.add(wrapper.wrap(row)); + } + return set; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadMetrics.java 
b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadMetrics.java new file mode 100644 index 000000000000..895861e95948 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadMetrics.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; +import static scala.collection.JavaConverters.seqAsJavaListConverter; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.execution.SparkPlan; +import org.apache.spark.sql.execution.metric.SQLMetric; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import scala.collection.JavaConverters; + +public class TestSparkReadMetrics extends TestBaseWithCatalog { + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testReadMetricsForV1Table() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT) USING iceberg TBLPROPERTIES ('format-version'='1')", + tableName); + + spark.range(10000).coalesce(1).writeTo(tableName).append(); + spark.range(10001, 20000).coalesce(1).writeTo(tableName).append(); + + Dataset df = spark.sql(String.format("select * from %s where id < 10000", tableName)); + df.collect(); + + List sparkPlans = + seqAsJavaListConverter(df.queryExecution().executedPlan().collectLeaves()).asJava(); + Map metricsMap = + JavaConverters.mapAsJavaMapConverter(sparkPlans.get(0).metrics()).asJava(); + // Common + assertThat(metricsMap.get("totalPlanningDuration").value()).isNotEqualTo(0); + + // data manifests + assertThat(metricsMap.get("totalDataManifest").value()).isEqualTo(2); + assertThat(metricsMap.get("scannedDataManifests").value()).isEqualTo(2); + assertThat(metricsMap.get("skippedDataManifests").value()).isEqualTo(0); + + // data files + assertThat(metricsMap.get("resultDataFiles").value()).isEqualTo(1); + assertThat(metricsMap.get("skippedDataFiles").value()).isEqualTo(1); + assertThat(metricsMap.get("totalDataFileSize").value()).isNotEqualTo(0); + + // delete manifests + assertThat(metricsMap.get("totalDeleteManifests").value()).isEqualTo(0); + assertThat(metricsMap.get("scannedDeleteManifests").value()).isEqualTo(0); + assertThat(metricsMap.get("skippedDeleteManifests").value()).isEqualTo(0); + + // delete files + assertThat(metricsMap.get("totalDeleteFileSize").value()).isEqualTo(0); + assertThat(metricsMap.get("resultDeleteFiles").value()).isEqualTo(0); + 
assertThat(metricsMap.get("equalityDeleteFiles").value()).isEqualTo(0); + assertThat(metricsMap.get("indexedDeleteFiles").value()).isEqualTo(0); + assertThat(metricsMap.get("positionalDeleteFiles").value()).isEqualTo(0); + assertThat(metricsMap.get("skippedDeleteFiles").value()).isEqualTo(0); + } + + @TestTemplate + public void testReadMetricsForV2Table() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT) USING iceberg TBLPROPERTIES ('format-version'='2')", + tableName); + + spark.range(10000).coalesce(1).writeTo(tableName).append(); + spark.range(10001, 20000).coalesce(1).writeTo(tableName).append(); + + Dataset df = spark.sql(String.format("select * from %s where id < 10000", tableName)); + df.collect(); + + List sparkPlans = + seqAsJavaListConverter(df.queryExecution().executedPlan().collectLeaves()).asJava(); + Map metricsMap = + JavaConverters.mapAsJavaMapConverter(sparkPlans.get(0).metrics()).asJava(); + + // Common + assertThat(metricsMap.get("totalPlanningDuration").value()).isNotEqualTo(0); + + // data manifests + assertThat(metricsMap.get("totalDataManifest").value()).isEqualTo(2); + assertThat(metricsMap.get("scannedDataManifests").value()).isEqualTo(2); + assertThat(metricsMap.get("skippedDataManifests").value()).isEqualTo(0); + + // data files + assertThat(metricsMap.get("resultDataFiles").value()).isEqualTo(1); + assertThat(metricsMap.get("skippedDataFiles").value()).isEqualTo(1); + assertThat(metricsMap.get("totalDataFileSize").value()).isNotEqualTo(0); + + // delete manifests + assertThat(metricsMap.get("totalDeleteManifests").value()).isEqualTo(0); + assertThat(metricsMap.get("scannedDeleteManifests").value()).isEqualTo(0); + assertThat(metricsMap.get("skippedDeleteManifests").value()).isEqualTo(0); + + // delete files + assertThat(metricsMap.get("totalDeleteFileSize").value()).isEqualTo(0); + assertThat(metricsMap.get("resultDeleteFiles").value()).isEqualTo(0); + assertThat(metricsMap.get("equalityDeleteFiles").value()).isEqualTo(0); + assertThat(metricsMap.get("indexedDeleteFiles").value()).isEqualTo(0); + assertThat(metricsMap.get("positionalDeleteFiles").value()).isEqualTo(0); + assertThat(metricsMap.get("skippedDeleteFiles").value()).isEqualTo(0); + } + + @TestTemplate + public void testDeleteMetrics() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT)" + + " USING iceberg" + + " TBLPROPERTIES (\n" + + " 'write.delete.mode'='merge-on-read',\n" + + " 'write.update.mode'='merge-on-read',\n" + + " 'write.merge.mode'='merge-on-read',\n" + + " 'format-version'='2'\n" + + " )", + tableName); + + spark.range(10000).coalesce(1).writeTo(tableName).append(); + + spark.sql(String.format("DELETE FROM %s WHERE id = 1", tableName)).collect(); + Dataset df = spark.sql(String.format("SELECT * FROM %s", tableName)); + df.collect(); + + List sparkPlans = + seqAsJavaListConverter(df.queryExecution().executedPlan().collectLeaves()).asJava(); + Map metricsMap = + JavaConverters.mapAsJavaMapConverter(sparkPlans.get(0).metrics()).asJava(); + + // Common + assertThat(metricsMap.get("totalPlanningDuration").value()).isNotEqualTo(0); + + // data manifests + assertThat(metricsMap.get("totalDataManifest").value()).isEqualTo(1); + assertThat(metricsMap.get("scannedDataManifests").value()).isEqualTo(1); + assertThat(metricsMap.get("skippedDataManifests").value()).isEqualTo(0); + + // data files + assertThat(metricsMap.get("resultDataFiles").value()).isEqualTo(1); + assertThat(metricsMap.get("skippedDataFiles").value()).isEqualTo(0); + 
assertThat(metricsMap.get("totalDataFileSize").value()).isNotEqualTo(0); + + // delete manifests + assertThat(metricsMap.get("totalDeleteManifests").value()).isEqualTo(1); + assertThat(metricsMap.get("scannedDeleteManifests").value()).isEqualTo(1); + assertThat(metricsMap.get("skippedDeleteManifests").value()).isEqualTo(0); + + // delete files + assertThat(metricsMap.get("totalDeleteFileSize").value()).isNotEqualTo(0); + assertThat(metricsMap.get("resultDeleteFiles").value()).isEqualTo(1); + assertThat(metricsMap.get("equalityDeleteFiles").value()).isEqualTo(0); + assertThat(metricsMap.get("indexedDeleteFiles").value()).isEqualTo(1); + assertThat(metricsMap.get("positionalDeleteFiles").value()).isEqualTo(1); + assertThat(metricsMap.get("skippedDeleteFiles").value()).isEqualTo(0); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java new file mode 100644 index 000000000000..99a327402d97 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.Files.localOutput; +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PlanningMode; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkValueConverter; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSparkReadProjection extends TestReadProjection { + + private static SparkSession spark = null; + + @Parameters(name = "format = {0}, vectorized = {1}, planningMode = {2}") + public static Object[][] parameters() { + return new Object[][] { + {FileFormat.PARQUET, false, LOCAL}, + {FileFormat.PARQUET, true, DISTRIBUTED}, + {FileFormat.AVRO, false, LOCAL}, + {FileFormat.ORC, false, DISTRIBUTED}, + {FileFormat.ORC, true, LOCAL} + }; + } + + @Parameter(index = 1) + private boolean vectorized; + + @Parameter(index = 2) + private PlanningMode planningMode; + + @BeforeAll + public static void startSpark() { + TestSparkReadProjection.spark = SparkSession.builder().master("local[2]").getOrCreate(); + ImmutableMap config = + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "parquet-enabled", "true", + "cache-enabled", "false"); + spark + .conf() + .set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.source.TestSparkCatalog"); + config.forEach( + (key, value) -> spark.conf().set("spark.sql.catalog.spark_catalog." 
+ key, value)); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestSparkReadProjection.spark; + TestSparkReadProjection.spark = null; + currentSpark.stop(); + } + + @Override + protected Record writeAndRead(String desc, Schema writeSchema, Schema readSchema, Record record) + throws IOException { + File parent = new File(temp.toFile(), desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + assertThat(dataFolder.mkdirs()).as("mkdirs should succeed").isTrue(); + + File testFile = new File(dataFolder, format.addExtension(UUID.randomUUID().toString())); + + Table table = + TestTables.create( + location, + desc, + writeSchema, + PartitionSpec.unpartitioned(), + ImmutableMap.of( + TableProperties.DATA_PLANNING_MODE, planningMode.modeName(), + TableProperties.DELETE_PLANNING_MODE, planningMode.modeName())); + try { + // Important: use the table's schema for the rest of the test + // When tables are created, the column ids are reassigned. + Schema tableSchema = table.schema(); + + try (FileAppender writer = + new GenericAppenderFactory(tableSchema).newAppender(localOutput(testFile), format)) { + writer.add(record); + } + + DataFile file = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(100) + .withFileSizeInBytes(testFile.length()) + .withPath(testFile.toString()) + .build(); + + table.newAppend().appendFile(file).commit(); + + // rewrite the read schema for the table's reassigned ids + Map idMapping = Maps.newHashMap(); + for (int id : allIds(writeSchema)) { + // translate each id to the original schema's column name, then to the new schema's id + String originalName = writeSchema.findColumnName(id); + idMapping.put(id, tableSchema.findField(originalName).fieldId()); + } + Schema expectedSchema = reassignIds(readSchema, idMapping); + + // Set the schema to the expected schema directly to simulate the table schema evolving + TestTables.replaceMetadata( + desc, TestTables.readMetadata(desc).updateSchema(expectedSchema, 100)); + + Dataset df = + spark + .read() + .format("org.apache.iceberg.spark.source.TestIcebergSource") + .option("iceberg.table.name", desc) + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(); + + return SparkValueConverter.convert(readSchema, df.collectAsList().get(0)); + + } finally { + TestTables.clearTables(); + } + } + + private List allIds(Schema schema) { + List ids = Lists.newArrayList(); + TypeUtil.visit( + schema, + new TypeUtil.SchemaVisitor() { + @Override + public Void field(Types.NestedField field, Void fieldResult) { + ids.add(field.fieldId()); + return null; + } + + @Override + public Void list(Types.ListType list, Void elementResult) { + ids.add(list.elementId()); + return null; + } + + @Override + public Void map(Types.MapType map, Void keyResult, Void valueResult) { + ids.add(map.keyId()); + ids.add(map.valueId()); + return null; + } + }); + return ids; + } + + private Schema reassignIds(Schema schema, Map idMapping) { + return new Schema( + TypeUtil.visit( + schema, + new TypeUtil.SchemaVisitor() { + private int mapId(int id) { + if (idMapping.containsKey(id)) { + return idMapping.get(id); + } + return 1000 + id; // make sure the new IDs don't conflict with reassignment + } + + @Override + public Type schema(Schema schema, Type structResult) { + return structResult; + } + + @Override + public Type struct(Types.StructType struct, List fieldResults) { + List newFields = + 
Lists.newArrayListWithExpectedSize(fieldResults.size()); + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + if (field.isOptional()) { + newFields.add( + optional(mapId(field.fieldId()), field.name(), fieldResults.get(i))); + } else { + newFields.add( + required(mapId(field.fieldId()), field.name(), fieldResults.get(i))); + } + } + return Types.StructType.of(newFields); + } + + @Override + public Type field(Types.NestedField field, Type fieldResult) { + return fieldResult; + } + + @Override + public Type list(Types.ListType list, Type elementResult) { + if (list.isElementOptional()) { + return Types.ListType.ofOptional(mapId(list.elementId()), elementResult); + } else { + return Types.ListType.ofRequired(mapId(list.elementId()), elementResult); + } + } + + @Override + public Type map(Types.MapType map, Type keyResult, Type valueResult) { + if (map.isValueOptional()) { + return Types.MapType.ofOptional( + mapId(map.keyId()), mapId(map.valueId()), keyResult, valueResult); + } else { + return Types.MapType.ofRequired( + mapId(map.keyId()), mapId(map.valueId()), keyResult, valueResult); + } + } + + @Override + public Type primitive(Type.PrimitiveType primitive) { + return primitive; + } + }) + .asNestedType() + .asStructType() + .fields()); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderDeletes.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderDeletes.java new file mode 100644 index 000000000000..d1ed1dc2b3cf --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderDeletes.java @@ -0,0 +1,690 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.apache.iceberg.spark.source.SparkSQLExecutionHelper.lastExecutedMetricValue; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Set; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PlanningMode; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.DeleteReadTests; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.InternalRecordWrapper; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.hive.HiveCatalog; +import org.apache.iceberg.hive.TestHiveMetastore; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkStructLike; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.spark.source.metrics.NumDeletes; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.CharSequenceSet; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.parquet.hadoop.ParquetFileWriter; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSparkReaderDeletes extends 
DeleteReadTests { + + private static TestHiveMetastore metastore = null; + protected static SparkSession spark = null; + protected static HiveCatalog catalog = null; + + @Parameter(index = 2) + private boolean vectorized; + + @Parameter(index = 3) + private PlanningMode planningMode; + + @Parameters(name = "fileFormat = {0}, formatVersion = {1}, vectorized = {2}, planningMode = {3}") + public static Object[][] parameters() { + return new Object[][] { + new Object[] {FileFormat.PARQUET, 2, false, PlanningMode.DISTRIBUTED}, + new Object[] {FileFormat.PARQUET, 2, true, PlanningMode.LOCAL}, + new Object[] {FileFormat.ORC, 2, false, PlanningMode.DISTRIBUTED}, + new Object[] {FileFormat.AVRO, 2, false, PlanningMode.LOCAL}, + new Object[] {FileFormat.PARQUET, 3, false, PlanningMode.DISTRIBUTED}, + new Object[] {FileFormat.PARQUET, 3, true, PlanningMode.LOCAL}, + }; + } + + @BeforeAll + public static void startMetastoreAndSpark() { + metastore = new TestHiveMetastore(); + metastore.start(); + HiveConf hiveConf = metastore.hiveConf(); + + spark = + SparkSession.builder() + .master("local[2]") + .config("spark.appStateStore.asyncTracking.enable", false) + .config("spark.ui.liveUpdate.period", 0) + .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic") + .config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .enableHiveSupport() + .getOrCreate(); + + catalog = + (HiveCatalog) + CatalogUtil.loadCatalog( + HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf); + + try { + catalog.createNamespace(Namespace.of("default")); + } catch (AlreadyExistsException ignored) { + // the default namespace already exists. ignore the create error + } + } + + @AfterAll + public static void stopMetastoreAndSpark() throws Exception { + catalog = null; + metastore.stop(); + metastore = null; + spark.stop(); + spark = null; + } + + @AfterEach + @Override + public void cleanup() throws IOException { + super.cleanup(); + dropTable("test3"); + } + + @Override + protected Table createTable(String name, Schema schema, PartitionSpec spec) { + Table table = catalog.createTable(TableIdentifier.of("default", name), schema); + TableOperations ops = ((BaseTable) table).operations(); + TableMetadata meta = ops.current(); + ops.commit(meta, meta.upgradeToFormatVersion(formatVersion)); + table + .updateProperties() + .set(TableProperties.DEFAULT_FILE_FORMAT, format.name()) + .set(TableProperties.DATA_PLANNING_MODE, planningMode.modeName()) + .set(TableProperties.DELETE_PLANNING_MODE, planningMode.modeName()) + .set(TableProperties.FORMAT_VERSION, String.valueOf(formatVersion)) + .commit(); + if (format.equals(FileFormat.PARQUET) || format.equals(FileFormat.ORC)) { + String vectorizationEnabled = + format.equals(FileFormat.PARQUET) + ? TableProperties.PARQUET_VECTORIZATION_ENABLED + : TableProperties.ORC_VECTORIZATION_ENABLED; + String batchSize = + format.equals(FileFormat.PARQUET) + ? 
TableProperties.PARQUET_BATCH_SIZE + : TableProperties.ORC_BATCH_SIZE; + table.updateProperties().set(vectorizationEnabled, String.valueOf(vectorized)).commit(); + if (vectorized) { + // split 7 records to two batches to cover more code paths + table.updateProperties().set(batchSize, "4").commit(); + } + } + return table; + } + + @Override + protected void dropTable(String name) { + catalog.dropTable(TableIdentifier.of("default", name)); + } + + protected boolean countDeletes() { + return true; + } + + @Override + protected long deleteCount() { + return Long.parseLong(lastExecutedMetricValue(spark, NumDeletes.DISPLAY_STRING)); + } + + @Override + public StructLikeSet rowSet(String name, Table table, String... columns) { + return rowSet(name, table.schema().select(columns).asStruct(), columns); + } + + public StructLikeSet rowSet(String name, Types.StructType projection, String... columns) { + Dataset df = + spark + .read() + .format("iceberg") + .load(TableIdentifier.of("default", name).toString()) + .selectExpr(columns); + + StructLikeSet set = StructLikeSet.create(projection); + df.collectAsList() + .forEach( + row -> { + SparkStructLike rowWrapper = new SparkStructLike(projection); + set.add(rowWrapper.wrap(row)); + }); + + return set; + } + + @TestTemplate + public void testEqualityDeleteWithFilter() throws IOException { + String tableName = table.name().substring(table.name().lastIndexOf(".") + 1); + Schema deleteRowSchema = table.schema().select("data"); + Record dataDelete = GenericRecord.create(deleteRowSchema); + List dataDeletes = + Lists.newArrayList( + dataDelete.copy("data", "a"), // id = 29 + dataDelete.copy("data", "d"), // id = 89 + dataDelete.copy("data", "g") // id = 122 + ); + + DeleteFile eqDeletes = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + dataDeletes, + deleteRowSchema); + + table.newRowDelta().addDeletes(eqDeletes).commit(); + + Types.StructType projection = table.schema().select("*").asStruct(); + Dataset df = + spark + .read() + .format("iceberg") + .load(TableIdentifier.of("default", tableName).toString()) + .filter("data = 'a'") // select a deleted row + .selectExpr("*"); + + StructLikeSet actual = StructLikeSet.create(projection); + df.collectAsList() + .forEach( + row -> { + SparkStructLike rowWrapper = new SparkStructLike(projection); + actual.add(rowWrapper.wrap(row)); + }); + + assertThat(actual).as("Table should contain no rows").hasSize(0); + } + + @TestTemplate + public void testReadEqualityDeleteRows() throws IOException { + Schema deleteSchema1 = table.schema().select("data"); + Record dataDelete = GenericRecord.create(deleteSchema1); + List dataDeletes = + Lists.newArrayList( + dataDelete.copy("data", "a"), // id = 29 + dataDelete.copy("data", "d") // id = 89 + ); + + Schema deleteSchema2 = table.schema().select("id"); + Record idDelete = GenericRecord.create(deleteSchema2); + List idDeletes = + Lists.newArrayList( + idDelete.copy("id", 121), // id = 121 + idDelete.copy("id", 122) // id = 122 + ); + + DeleteFile eqDelete1 = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + dataDeletes, + deleteSchema1); + + DeleteFile eqDelete2 = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + idDeletes, + deleteSchema2); + + 
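+ // both equality delete files are committed in a single row delta, so the reader below is expected to surface all four deleted rows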
table.newRowDelta().addDeletes(eqDelete1).addDeletes(eqDelete2).commit(); + + StructLikeSet expectedRowSet = rowSetWithIds(29, 89, 121, 122); + + Types.StructType type = table.schema().asStruct(); + StructLikeSet actualRowSet = StructLikeSet.create(type); + + CloseableIterable tasks = + TableScanUtil.planTasks( + table.newScan().planFiles(), + TableProperties.METADATA_SPLIT_SIZE_DEFAULT, + TableProperties.SPLIT_LOOKBACK_DEFAULT, + TableProperties.SPLIT_OPEN_FILE_COST_DEFAULT); + + for (CombinedScanTask task : tasks) { + try (EqualityDeleteRowReader reader = + new EqualityDeleteRowReader(task, table, null, table.schema(), false)) { + while (reader.next()) { + actualRowSet.add( + new InternalRowWrapper( + SparkSchemaUtil.convert(table.schema()), table.schema().asStruct()) + .wrap(reader.get().copy())); + } + } + } + + assertThat(actualRowSet).as("should include 4 deleted row").hasSize(4); + assertThat(actualRowSet).as("deleted row should be matched").isEqualTo(expectedRowSet); + } + + @TestTemplate + public void testPosDeletesAllRowsInBatch() throws IOException { + // read.parquet.vectorization.batch-size is set to 4, so the 4 rows in the first batch are all + // deleted. + List> deletes = + Lists.newArrayList( + Pair.of(dataFile.location(), 0L), // id = 29 + Pair.of(dataFile.location(), 1L), // id = 43 + Pair.of(dataFile.location(), 2L), // id = 61 + Pair.of(dataFile.location(), 3L) // id = 89 + ); + + Pair posDeletes = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + deletes, + formatVersion); + + table + .newRowDelta() + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + StructLikeSet expected = rowSetWithoutIds(table, records, 29, 43, 61, 89); + StructLikeSet actual = rowSet(tableName, table, "*"); + + assertThat(actual).as("Table should contain expected rows").isEqualTo(expected); + checkDeleteCount(4L); + } + + @TestTemplate + public void testPosDeletesWithDeletedColumn() throws IOException { + // read.parquet.vectorization.batch-size is set to 4, so the 4 rows in the first batch are all + // deleted. 
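+ // the positions below are row ordinals within dataFile, matching the ids noted in the inline comments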
+ List> deletes = + Lists.newArrayList( + Pair.of(dataFile.location(), 0L), // id = 29 + Pair.of(dataFile.location(), 1L), // id = 43 + Pair.of(dataFile.location(), 2L), // id = 61 + Pair.of(dataFile.location(), 3L) // id = 89 + ); + + Pair posDeletes = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + deletes, + formatVersion); + + table + .newRowDelta() + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + StructLikeSet expected = expectedRowSet(29, 43, 61, 89); + StructLikeSet actual = + rowSet(tableName, PROJECTION_SCHEMA.asStruct(), "id", "data", "_deleted"); + + assertThat(actual).as("Table should contain expected row").isEqualTo(expected); + checkDeleteCount(4L); + } + + @TestTemplate + public void testEqualityDeleteWithDeletedColumn() throws IOException { + String tableName = table.name().substring(table.name().lastIndexOf(".") + 1); + Schema deleteRowSchema = table.schema().select("data"); + Record dataDelete = GenericRecord.create(deleteRowSchema); + List dataDeletes = + Lists.newArrayList( + dataDelete.copy("data", "a"), // id = 29 + dataDelete.copy("data", "d"), // id = 89 + dataDelete.copy("data", "g") // id = 122 + ); + + DeleteFile eqDeletes = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + dataDeletes, + deleteRowSchema); + + table.newRowDelta().addDeletes(eqDeletes).commit(); + + StructLikeSet expected = expectedRowSet(29, 89, 122); + StructLikeSet actual = + rowSet(tableName, PROJECTION_SCHEMA.asStruct(), "id", "data", "_deleted"); + + assertThat(actual).as("Table should contain expected row").isEqualTo(expected); + checkDeleteCount(3L); + } + + @TestTemplate + public void testMixedPosAndEqDeletesWithDeletedColumn() throws IOException { + Schema dataSchema = table.schema().select("data"); + Record dataDelete = GenericRecord.create(dataSchema); + List dataDeletes = + Lists.newArrayList( + dataDelete.copy("data", "a"), // id = 29 + dataDelete.copy("data", "d"), // id = 89 + dataDelete.copy("data", "g") // id = 122 + ); + + DeleteFile eqDeletes = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + dataDeletes, + dataSchema); + + List> deletes = + Lists.newArrayList( + Pair.of(dataFile.location(), 3L), // id = 89 + Pair.of(dataFile.location(), 5L) // id = 121 + ); + + Pair posDeletes = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + deletes, + formatVersion); + + table + .newRowDelta() + .addDeletes(eqDeletes) + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + StructLikeSet expected = expectedRowSet(29, 89, 121, 122); + StructLikeSet actual = + rowSet(tableName, PROJECTION_SCHEMA.asStruct(), "id", "data", "_deleted"); + + assertThat(actual).as("Table should contain expected row").isEqualTo(expected); + checkDeleteCount(4L); + } + + @TestTemplate + public void testFilterOnDeletedMetadataColumn() throws IOException { + List> deletes = + Lists.newArrayList( + Pair.of(dataFile.location(), 0L), // id = 29 + Pair.of(dataFile.location(), 1L), // id = 43 + Pair.of(dataFile.location(), 2L), // id = 61 + Pair.of(dataFile.location(), 3L) // id = 89 + ); + + Pair posDeletes = + FileHelpers.writeDeleteFile( + table, + 
Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + deletes, + formatVersion); + + table + .newRowDelta() + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + StructLikeSet expected = expectedRowSetWithNonDeletesOnly(29, 43, 61, 89); + + // get non-deleted rows + Dataset df = + spark + .read() + .format("iceberg") + .load(TableIdentifier.of("default", tableName).toString()) + .select("id", "data", "_deleted") + .filter("_deleted = false"); + + Types.StructType projection = PROJECTION_SCHEMA.asStruct(); + StructLikeSet actual = StructLikeSet.create(projection); + df.collectAsList() + .forEach( + row -> { + SparkStructLike rowWrapper = new SparkStructLike(projection); + actual.add(rowWrapper.wrap(row)); + }); + + assertThat(actual).as("Table should contain expected row").isEqualTo(expected); + + StructLikeSet expectedDeleted = expectedRowSetWithDeletesOnly(29, 43, 61, 89); + + // get deleted rows + df = + spark + .read() + .format("iceberg") + .load(TableIdentifier.of("default", tableName).toString()) + .select("id", "data", "_deleted") + .filter("_deleted = true"); + + StructLikeSet actualDeleted = StructLikeSet.create(projection); + df.collectAsList() + .forEach( + row -> { + SparkStructLike rowWrapper = new SparkStructLike(projection); + actualDeleted.add(rowWrapper.wrap(row)); + }); + + assertThat(actualDeleted).as("Table should contain expected row").isEqualTo(expectedDeleted); + } + + @TestTemplate + public void testIsDeletedColumnWithoutDeleteFile() { + StructLikeSet expected = expectedRowSet(); + StructLikeSet actual = + rowSet(tableName, PROJECTION_SCHEMA.asStruct(), "id", "data", "_deleted"); + assertThat(actual).as("Table should contain expected row").isEqualTo(expected); + checkDeleteCount(0L); + } + + @TestTemplate + public void testPosDeletesOnParquetFileWithMultipleRowGroups() throws IOException { + assumeThat(format).isEqualTo(FileFormat.PARQUET); + + String tblName = "test3"; + Table tbl = createTable(tblName, SCHEMA, PartitionSpec.unpartitioned()); + + List fileSplits = Lists.newArrayList(); + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + Configuration conf = new Configuration(); + File testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should succeed").isTrue(); + Path testFilePath = new Path(testFile.getAbsolutePath()); + + // Write a Parquet file with more than one row group + ParquetFileWriter parquetFileWriter = + new ParquetFileWriter(conf, ParquetSchemaUtil.convert(SCHEMA, "test3Schema"), testFilePath); + parquetFileWriter.start(); + for (int i = 0; i < 2; i += 1) { + File split = File.createTempFile("junit", null, temp.toFile()); + assertThat(split.delete()).as("Delete should succeed").isTrue(); + Path splitPath = new Path(split.getAbsolutePath()); + fileSplits.add(splitPath); + try (FileAppender writer = + Parquet.write(Files.localOutput(split)) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(sparkSchema, msgType)) + .schema(SCHEMA) + .overwrite() + .build()) { + Iterable records = RandomData.generateSpark(SCHEMA, 100, 34 * i + 37); + writer.addAll(records); + } + parquetFileWriter.appendFile( + org.apache.parquet.hadoop.util.HadoopInputFile.fromPath(splitPath, conf)); + } + parquetFileWriter.end( + ParquetFileWriter.mergeMetadataFiles(fileSplits, conf) + .getFileMetaData() + .getKeyValueMetaData()); + + // Add the file to the table + DataFile dataFile = + 
DataFiles.builder(PartitionSpec.unpartitioned()) + .withInputFile(org.apache.iceberg.hadoop.HadoopInputFile.fromPath(testFilePath, conf)) + .withFormat("parquet") + .withRecordCount(200) + .build(); + tbl.newAppend().appendFile(dataFile).commit(); + + // Add positional deletes to the table + List> deletes = + Lists.newArrayList( + Pair.of(dataFile.location(), 97L), + Pair.of(dataFile.location(), 98L), + Pair.of(dataFile.location(), 99L), + Pair.of(dataFile.location(), 101L), + Pair.of(dataFile.location(), 103L), + Pair.of(dataFile.location(), 107L), + Pair.of(dataFile.location(), 109L)); + Pair posDeletes = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + deletes, + formatVersion); + tbl.newRowDelta() + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + assertThat(rowSet(tblName, tbl, "*")).hasSize(193); + } + + private static final Schema PROJECTION_SCHEMA = + new Schema( + required(1, "id", Types.IntegerType.get()), + required(2, "data", Types.StringType.get()), + MetadataColumns.IS_DELETED); + + private static StructLikeSet expectedRowSet(int... idsToRemove) { + return expectedRowSet(false, false, idsToRemove); + } + + private static StructLikeSet expectedRowSetWithDeletesOnly(int... idsToRemove) { + return expectedRowSet(false, true, idsToRemove); + } + + private static StructLikeSet expectedRowSetWithNonDeletesOnly(int... idsToRemove) { + return expectedRowSet(true, false, idsToRemove); + } + + private static StructLikeSet expectedRowSet( + boolean removeDeleted, boolean removeNonDeleted, int... idsToRemove) { + Set deletedIds = Sets.newHashSet(ArrayUtil.toIntList(idsToRemove)); + List records = recordsWithDeletedColumn(); + // mark rows deleted + records.forEach( + record -> { + if (deletedIds.contains(record.getField("id"))) { + record.setField(MetadataColumns.IS_DELETED.name(), true); + } + }); + + records.removeIf(record -> deletedIds.contains(record.getField("id")) && removeDeleted); + records.removeIf(record -> !deletedIds.contains(record.getField("id")) && removeNonDeleted); + + StructLikeSet set = StructLikeSet.create(PROJECTION_SCHEMA.asStruct()); + records.forEach( + record -> set.add(new InternalRecordWrapper(PROJECTION_SCHEMA.asStruct()).wrap(record))); + + return set; + } + + @NotNull + private static List recordsWithDeletedColumn() { + List records = Lists.newArrayList(); + + // records all use IDs that are in bucket id_bucket=0 + GenericRecord record = GenericRecord.create(PROJECTION_SCHEMA); + records.add(record.copy("id", 29, "data", "a", "_deleted", false)); + records.add(record.copy("id", 43, "data", "b", "_deleted", false)); + records.add(record.copy("id", 61, "data", "c", "_deleted", false)); + records.add(record.copy("id", 89, "data", "d", "_deleted", false)); + records.add(record.copy("id", 100, "data", "e", "_deleted", false)); + records.add(record.copy("id", 121, "data", "f", "_deleted", false)); + records.add(record.copy("id", 122, "data", "g", "_deleted", false)); + return records; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderWithBloomFilter.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderWithBloomFilter.java new file mode 100644 index 000000000000..baf7fa8f88a2 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderWithBloomFilter.java @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one 
+ * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT_DEFAULT; +import static org.apache.iceberg.TableProperties.PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.Closeable; +import java.io.File; +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.file.Path; +import java.time.LocalDate; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Files; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers.Row; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.hive.HiveCatalog; +import org.apache.iceberg.hive.TestHiveMetastore; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkValueConverter; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSparkReaderWithBloomFilter { + + protected String tableName = null; + protected Table 
table = null; + protected List records = null; + protected DataFile dataFile = null; + + private static TestHiveMetastore metastore = null; + protected static SparkSession spark = null; + protected static HiveCatalog catalog = null; + + @Parameter(index = 0) + protected boolean vectorized; + + @Parameter(index = 1) + protected boolean useBloomFilter; + + // Schema passed to create tables + public static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.required(2, "id_long", Types.LongType.get()), + Types.NestedField.required(3, "id_double", Types.DoubleType.get()), + Types.NestedField.required(4, "id_float", Types.FloatType.get()), + Types.NestedField.required(5, "id_string", Types.StringType.get()), + Types.NestedField.optional(6, "id_boolean", Types.BooleanType.get()), + Types.NestedField.optional(7, "id_date", Types.DateType.get()), + Types.NestedField.optional(8, "id_int_decimal", Types.DecimalType.of(8, 2)), + Types.NestedField.optional(9, "id_long_decimal", Types.DecimalType.of(14, 2)), + Types.NestedField.optional(10, "id_fixed_decimal", Types.DecimalType.of(31, 2))); + + private static final int INT_MIN_VALUE = 30; + private static final int INT_MAX_VALUE = 329; + private static final int INT_VALUE_COUNT = INT_MAX_VALUE - INT_MIN_VALUE + 1; + private static final long LONG_BASE = 1000L; + private static final double DOUBLE_BASE = 10000D; + private static final float FLOAT_BASE = 100000F; + private static final String BINARY_PREFIX = "BINARY测试_"; + + @TempDir private Path temp; + + @BeforeEach + public void writeTestDataFile() throws IOException { + this.tableName = "test"; + createTable(tableName, SCHEMA); + this.records = Lists.newArrayList(); + + // records all use IDs that are in bucket id_bucket=0 + GenericRecord record = GenericRecord.create(table.schema()); + + for (int i = 0; i < INT_VALUE_COUNT; i += 1) { + records.add( + record.copy( + ImmutableMap.of( + "id", + INT_MIN_VALUE + i, + "id_long", + LONG_BASE + INT_MIN_VALUE + i, + "id_double", + DOUBLE_BASE + INT_MIN_VALUE + i, + "id_float", + FLOAT_BASE + INT_MIN_VALUE + i, + "id_string", + BINARY_PREFIX + (INT_MIN_VALUE + i), + "id_boolean", + i % 2 == 0, + "id_date", + LocalDate.parse("2021-09-05"), + "id_int_decimal", + new BigDecimal(String.valueOf(77.77)), + "id_long_decimal", + new BigDecimal(String.valueOf(88.88)), + "id_fixed_decimal", + new BigDecimal(String.valueOf(99.99))))); + } + + this.dataFile = + writeDataFile( + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + Row.of(0), + records); + + table.newAppend().appendFile(dataFile).commit(); + } + + @AfterEach + public void cleanup() throws IOException { + dropTable("test"); + } + + @Parameters(name = "vectorized = {0}, useBloomFilter = {1}") + public static Object[][] parameters() { + return new Object[][] {{false, false}, {true, false}, {false, true}, {true, true}}; + } + + @BeforeAll + public static void startMetastoreAndSpark() { + metastore = new TestHiveMetastore(); + metastore.start(); + HiveConf hiveConf = metastore.hiveConf(); + + spark = + SparkSession.builder() + .master("local[2]") + .config("spark.hadoop." 
+ METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .enableHiveSupport() + .getOrCreate(); + + catalog = + (HiveCatalog) + CatalogUtil.loadCatalog( + HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf); + + try { + catalog.createNamespace(Namespace.of("default")); + } catch (AlreadyExistsException ignored) { + // the default namespace already exists. ignore the create error + } + } + + @AfterAll + public static void stopMetastoreAndSpark() throws Exception { + catalog = null; + metastore.stop(); + metastore = null; + spark.stop(); + spark = null; + } + + protected void createTable(String name, Schema schema) { + table = catalog.createTable(TableIdentifier.of("default", name), schema); + TableOperations ops = ((BaseTable) table).operations(); + TableMetadata meta = ops.current(); + ops.commit(meta, meta.upgradeToFormatVersion(2)); + + if (useBloomFilter) { + table + .updateProperties() + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_long", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_double", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_float", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_string", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_boolean", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_date", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_int_decimal", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_long_decimal", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_fixed_decimal", "true") + .commit(); + } + + table + .updateProperties() + .set(TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, "100") // to have multiple row groups + .commit(); + if (vectorized) { + table + .updateProperties() + .set(TableProperties.PARQUET_VECTORIZATION_ENABLED, "true") + .set(TableProperties.PARQUET_BATCH_SIZE, "4") + .commit(); + } + } + + protected void dropTable(String name) { + catalog.dropTable(TableIdentifier.of("default", name)); + } + + private DataFile writeDataFile(OutputFile out, StructLike partition, List rows) + throws IOException { + FileFormat format = defaultFormat(table.properties()); + GenericAppenderFactory factory = new GenericAppenderFactory(table.schema(), table.spec()); + + boolean useBloomFilterCol1 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id", Boolean.toString(useBloomFilterCol1)); + boolean useBloomFilterCol2 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_long", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_long", + Boolean.toString(useBloomFilterCol2)); + boolean useBloomFilterCol3 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_double", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_double", + Boolean.toString(useBloomFilterCol3)); + boolean useBloomFilterCol4 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_float", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_float", + Boolean.toString(useBloomFilterCol4)); + boolean useBloomFilterCol5 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + 
"id_string", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_string", + Boolean.toString(useBloomFilterCol5)); + boolean useBloomFilterCol6 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_boolean", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_boolean", + Boolean.toString(useBloomFilterCol6)); + boolean useBloomFilterCol7 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_date", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_date", + Boolean.toString(useBloomFilterCol7)); + boolean useBloomFilterCol8 = + PropertyUtil.propertyAsBoolean( + table.properties(), + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_int_decimal", + false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_int_decimal", + Boolean.toString(useBloomFilterCol8)); + boolean useBloomFilterCol9 = + PropertyUtil.propertyAsBoolean( + table.properties(), + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_long_decimal", + false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_long_decimal", + Boolean.toString(useBloomFilterCol9)); + boolean useBloomFilterCol10 = + PropertyUtil.propertyAsBoolean( + table.properties(), + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_fixed_decimal", + false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_fixed_decimal", + Boolean.toString(useBloomFilterCol10)); + int blockSize = + PropertyUtil.propertyAsInt( + table.properties(), PARQUET_ROW_GROUP_SIZE_BYTES, PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT); + factory.set(PARQUET_ROW_GROUP_SIZE_BYTES, Integer.toString(blockSize)); + + FileAppender writer = factory.newAppender(out, format); + try (Closeable toClose = writer) { + writer.addAll(rows); + } + + return DataFiles.builder(table.spec()) + .withFormat(format) + .withPath(out.location()) + .withPartition(partition) + .withFileSizeInBytes(writer.length()) + .withSplitOffsets(writer.splitOffsets()) + .withMetrics(writer.metrics()) + .build(); + } + + private FileFormat defaultFormat(Map properties) { + String formatString = properties.getOrDefault(DEFAULT_FILE_FORMAT, DEFAULT_FILE_FORMAT_DEFAULT); + return FileFormat.fromString(formatString); + } + + @TestTemplate + public void testReadWithFilter() { + Dataset df = + spark + .read() + .format("iceberg") + .load(TableIdentifier.of("default", tableName).toString()) + // this is from the first row group + .filter( + "id = 30 AND id_long = 1030 AND id_double = 10030.0 AND id_float = 100030.0" + + " AND id_string = 'BINARY测试_30' AND id_boolean = true AND id_date = '2021-09-05'" + + " AND id_int_decimal = 77.77 AND id_long_decimal = 88.88 AND id_fixed_decimal = 99.99"); + + Record record = SparkValueConverter.convert(table.schema(), df.collectAsList().get(0)); + + assertThat(df.collectAsList()).as("Table should contain 1 row").hasSize(1); + assertThat(record.get(0)).as("Table should contain expected rows").isEqualTo(30); + + df = + spark + .read() + .format("iceberg") + .load(TableIdentifier.of("default", tableName).toString()) + // this is from the third row group + .filter( + "id = 250 AND id_long = 1250 AND id_double = 10250.0 AND id_float = 100250.0" + + " AND id_string = 'BINARY测试_250' AND id_boolean = true AND id_date = '2021-09-05'" + + " AND id_int_decimal = 77.77 AND id_long_decimal = 88.88 AND id_fixed_decimal = 99.99"); + + record = SparkValueConverter.convert(table.schema(), 
df.collectAsList().get(0)); + + assertThat(df.collectAsList()).as("Table should contain 1 row").hasSize(1); + assertThat(record.get(0)).as("Table should contain expected rows").isEqualTo(250); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkRollingFileWriters.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkRollingFileWriters.java new file mode 100644 index 000000000000..5ebeafcb8cef --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkRollingFileWriters.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestRollingFileWriters; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkRollingFileWriters extends TestRollingFileWriters { + + @Override + protected FileWriterFactory newWriterFactory( + Schema dataSchema, + List equalityFieldIds, + Schema equalityDeleteRowSchema, + Schema positionDeleteRowSchema) { + return SparkFileWriterFactory.builderFor(table) + .dataSchema(table.schema()) + .dataFileFormat(format()) + .deleteFileFormat(format()) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .equalityDeleteRowSchema(equalityDeleteRowSchema) + .positionDeleteRowSchema(positionDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java new file mode 100644 index 000000000000..dbb15ca5a743 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java @@ -0,0 +1,1061 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.puffin.StandardBlobTypes.APACHE_DATASKETCHES_THETA_V1; +import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.createPartitionedTable; +import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.createUnpartitionedTable; +import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToDayOrdinal; +import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToHourOrdinal; +import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToMonthOrdinal; +import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToYearOrdinal; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.GenericBlobMetadata; +import org.apache.iceberg.GenericStatisticsFile; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.iceberg.spark.functions.BucketFunction; +import org.apache.iceberg.spark.functions.DaysFunction; +import org.apache.iceberg.spark.functions.HoursFunction; +import org.apache.iceberg.spark.functions.MonthsFunction; +import org.apache.iceberg.spark.functions.TruncateFunction; +import org.apache.iceberg.spark.functions.YearsFunction; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.LiteralValue; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc; +import org.apache.spark.sql.connector.expressions.filter.And; +import org.apache.spark.sql.connector.expressions.filter.Not; +import org.apache.spark.sql.connector.expressions.filter.Or; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.Statistics; +import 
org.apache.spark.sql.connector.read.SupportsPushDownV2Filters; +import org.apache.spark.sql.connector.read.colstats.ColumnStatistics; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestSparkScan extends TestBaseWithCatalog { + + private static final String DUMMY_BLOB_TYPE = "sum-data-size-bytes-v1"; + + @Parameter(index = 3) + private String format; + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}, format = {3}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + "parquet" + }, + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + "avro" + }, + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + "orc" + } + }; + } + + @BeforeEach + public void useCatalog() { + sql("USE %s", catalogName); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testEstimatedRowCount() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, date DATE) USING iceberg TBLPROPERTIES('%s' = '%s')", + tableName, TableProperties.DEFAULT_FILE_FORMAT, format); + + Dataset df = + spark + .range(10000) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id AS INT)"))) + .select("id", "date"); + + df.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + SparkScanBuilder scanBuilder = + new SparkScanBuilder(spark, table, CaseInsensitiveStringMap.empty()); + SparkScan scan = (SparkScan) scanBuilder.build(); + Statistics stats = scan.estimateStatistics(); + + assertThat(stats.numRows().getAsLong()).isEqualTo(10000L); + } + + @TestTemplate + public void testTableWithoutColStats() throws NoSuchTableException { + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "a"), + new SimpleRecord(4, "b")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + Table table = validationCatalog.loadTable(tableIdent); + + SparkScanBuilder scanBuilder = + new SparkScanBuilder(spark, table, CaseInsensitiveStringMap.empty()); + SparkScan scan = (SparkScan) scanBuilder.build(); + + Map reportColStatsDisabled = + ImmutableMap.of( + SQLConf.CBO_ENABLED().key(), "true", SparkSQLProperties.REPORT_COLUMN_STATS, "false"); + + Map reportColStatsEnabled = + ImmutableMap.of(SQLConf.CBO_ENABLED().key(), "true"); + + checkColStatisticsNotReported(scan, 4L); + withSQLConf(reportColStatsDisabled, () -> checkColStatisticsNotReported(scan, 4L)); + // The expected col NDVs are nulls + withSQLConf( + reportColStatsEnabled, () -> checkColStatisticsReported(scan, 4L, Maps.newHashMap())); + } + + @TestTemplate + public void testTableWithoutApacheDatasketchColStat() throws NoSuchTableException { + sql("CREATE TABLE %s (id int, 
data string) USING iceberg", tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "a"), + new SimpleRecord(4, "b")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + SparkScanBuilder scanBuilder = + new SparkScanBuilder(spark, table, CaseInsensitiveStringMap.empty()); + SparkScan scan = (SparkScan) scanBuilder.build(); + + Map reportColStatsDisabled = + ImmutableMap.of( + SQLConf.CBO_ENABLED().key(), "true", SparkSQLProperties.REPORT_COLUMN_STATS, "false"); + + Map reportColStatsEnabled = + ImmutableMap.of(SQLConf.CBO_ENABLED().key(), "true"); + + GenericStatisticsFile statisticsFile = + new GenericStatisticsFile( + snapshotId, + "/test/statistics/file.puffin", + 100, + 42, + ImmutableList.of( + new GenericBlobMetadata( + DUMMY_BLOB_TYPE, + snapshotId, + 1, + ImmutableList.of(1), + ImmutableMap.of("data_size", "4")))); + + table.updateStatistics().setStatistics(snapshotId, statisticsFile).commit(); + + checkColStatisticsNotReported(scan, 4L); + withSQLConf(reportColStatsDisabled, () -> checkColStatisticsNotReported(scan, 4L)); + // The expected col NDVs are nulls + withSQLConf( + reportColStatsEnabled, () -> checkColStatisticsReported(scan, 4L, Maps.newHashMap())); + } + + @TestTemplate + public void testTableWithOneColStats() throws NoSuchTableException { + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "a"), + new SimpleRecord(4, "b")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + SparkScanBuilder scanBuilder = + new SparkScanBuilder(spark, table, CaseInsensitiveStringMap.empty()); + SparkScan scan = (SparkScan) scanBuilder.build(); + + Map reportColStatsDisabled = + ImmutableMap.of( + SQLConf.CBO_ENABLED().key(), "true", SparkSQLProperties.REPORT_COLUMN_STATS, "false"); + + Map reportColStatsEnabled = + ImmutableMap.of(SQLConf.CBO_ENABLED().key(), "true"); + + GenericStatisticsFile statisticsFile = + new GenericStatisticsFile( + snapshotId, + "/test/statistics/file.puffin", + 100, + 42, + ImmutableList.of( + new GenericBlobMetadata( + APACHE_DATASKETCHES_THETA_V1, + snapshotId, + 1, + ImmutableList.of(1), + ImmutableMap.of("ndv", "4")))); + + table.updateStatistics().setStatistics(snapshotId, statisticsFile).commit(); + + checkColStatisticsNotReported(scan, 4L); + withSQLConf(reportColStatsDisabled, () -> checkColStatisticsNotReported(scan, 4L)); + + Map expectedOneNDV = Maps.newHashMap(); + expectedOneNDV.put("id", 4L); + withSQLConf(reportColStatsEnabled, () -> checkColStatisticsReported(scan, 4L, expectedOneNDV)); + } + + @TestTemplate + public void testTableWithOneApacheDatasketchColStatAndOneDifferentColStat() + throws NoSuchTableException { + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "a"), + new SimpleRecord(4, "b")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + 
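+ // reload the table to get the snapshot id that the Puffin statistics file will be attached to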
Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + SparkScanBuilder scanBuilder = + new SparkScanBuilder(spark, table, CaseInsensitiveStringMap.empty()); + SparkScan scan = (SparkScan) scanBuilder.build(); + + Map reportColStatsDisabled = + ImmutableMap.of( + SQLConf.CBO_ENABLED().key(), "true", SparkSQLProperties.REPORT_COLUMN_STATS, "false"); + + Map reportColStatsEnabled = + ImmutableMap.of(SQLConf.CBO_ENABLED().key(), "true"); + + GenericStatisticsFile statisticsFile = + new GenericStatisticsFile( + snapshotId, + "/test/statistics/file.puffin", + 100, + 42, + ImmutableList.of( + new GenericBlobMetadata( + APACHE_DATASKETCHES_THETA_V1, + snapshotId, + 1, + ImmutableList.of(1), + ImmutableMap.of("ndv", "4")), + new GenericBlobMetadata( + DUMMY_BLOB_TYPE, + snapshotId, + 1, + ImmutableList.of(1), + ImmutableMap.of("data_size", "2")))); + + table.updateStatistics().setStatistics(snapshotId, statisticsFile).commit(); + + checkColStatisticsNotReported(scan, 4L); + withSQLConf(reportColStatsDisabled, () -> checkColStatisticsNotReported(scan, 4L)); + + Map expectedOneNDV = Maps.newHashMap(); + expectedOneNDV.put("id", 4L); + withSQLConf(reportColStatsEnabled, () -> checkColStatisticsReported(scan, 4L, expectedOneNDV)); + } + + @TestTemplate + public void testTableWithTwoColStats() throws NoSuchTableException { + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "a"), + new SimpleRecord(4, "b")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + SparkScanBuilder scanBuilder = + new SparkScanBuilder(spark, table, CaseInsensitiveStringMap.empty()); + SparkScan scan = (SparkScan) scanBuilder.build(); + + Map reportColStatsDisabled = + ImmutableMap.of( + SQLConf.CBO_ENABLED().key(), "true", SparkSQLProperties.REPORT_COLUMN_STATS, "false"); + + Map reportColStatsEnabled = + ImmutableMap.of(SQLConf.CBO_ENABLED().key(), "true"); + + GenericStatisticsFile statisticsFile = + new GenericStatisticsFile( + snapshotId, + "/test/statistics/file.puffin", + 100, + 42, + ImmutableList.of( + new GenericBlobMetadata( + APACHE_DATASKETCHES_THETA_V1, + snapshotId, + 1, + ImmutableList.of(1), + ImmutableMap.of("ndv", "4")), + new GenericBlobMetadata( + APACHE_DATASKETCHES_THETA_V1, + snapshotId, + 1, + ImmutableList.of(2), + ImmutableMap.of("ndv", "2")))); + + table.updateStatistics().setStatistics(snapshotId, statisticsFile).commit(); + + checkColStatisticsNotReported(scan, 4L); + withSQLConf(reportColStatsDisabled, () -> checkColStatisticsNotReported(scan, 4L)); + + Map expectedTwoNDVs = Maps.newHashMap(); + expectedTwoNDVs.put("id", 4L); + expectedTwoNDVs.put("data", 2L); + withSQLConf(reportColStatsEnabled, () -> checkColStatisticsReported(scan, 4L, expectedTwoNDVs)); + } + + @TestTemplate + public void testUnpartitionedYears() throws Exception { + createUnpartitionedTable(spark, tableName); + + SparkScanBuilder builder = scanBuilder(); + + YearsFunction.TimestampToYearsFunction function = new YearsFunction.TimestampToYearsFunction(); + UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts"))); + Predicate predicate = + new Predicate( + "=", + expressions( + udf, 
intLit(timestampStrToYearOrdinal("2017-11-22T00:00:00.000000+00:00")))); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + + // NOT Equal + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + } + + @TestTemplate + public void testPartitionedYears() throws Exception { + createPartitionedTable(spark, tableName, "years(ts)"); + + SparkScanBuilder builder = scanBuilder(); + + YearsFunction.TimestampToYearsFunction function = new YearsFunction.TimestampToYearsFunction(); + UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts"))); + Predicate predicate = + new Predicate( + "=", + expressions( + udf, intLit(timestampStrToYearOrdinal("2017-11-22T00:00:00.000000+00:00")))); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(5); + + // NOT Equal + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(5); + } + + @TestTemplate + public void testUnpartitionedMonths() throws Exception { + createUnpartitionedTable(spark, tableName); + + SparkScanBuilder builder = scanBuilder(); + + MonthsFunction.TimestampToMonthsFunction function = + new MonthsFunction.TimestampToMonthsFunction(); + UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts"))); + Predicate predicate = + new Predicate( + ">", + expressions( + udf, intLit(timestampStrToMonthOrdinal("2017-11-22T00:00:00.000000+00:00")))); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + + // NOT GT + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + } + + @TestTemplate + public void testPartitionedMonths() throws Exception { + createPartitionedTable(spark, tableName, "months(ts)"); + + SparkScanBuilder builder = scanBuilder(); + + MonthsFunction.TimestampToMonthsFunction function = + new MonthsFunction.TimestampToMonthsFunction(); + UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts"))); + Predicate predicate = + new Predicate( + ">", + expressions( + udf, intLit(timestampStrToMonthOrdinal("2017-11-22T00:00:00.000000+00:00")))); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(5); + + // NOT GT + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(5); + } + + @TestTemplate + public void testUnpartitionedDays() throws Exception { + createUnpartitionedTable(spark, tableName); + + SparkScanBuilder builder = scanBuilder(); + + DaysFunction.TimestampToDaysFunction function = new DaysFunction.TimestampToDaysFunction(); + UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts"))); + Predicate predicate = + new Predicate( + "<", + expressions( + udf, dateLit(timestampStrToDayOrdinal("2018-11-20T00:00:00.000000+00:00")))); + pushFilters(builder, predicate); + Batch scan = 
builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + + // NOT LT + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + } + + @TestTemplate + public void testPartitionedDays() throws Exception { + createPartitionedTable(spark, tableName, "days(ts)"); + + SparkScanBuilder builder = scanBuilder(); + + DaysFunction.TimestampToDaysFunction function = new DaysFunction.TimestampToDaysFunction(); + UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts"))); + Predicate predicate = + new Predicate( + "<", + expressions( + udf, dateLit(timestampStrToDayOrdinal("2018-11-20T00:00:00.000000+00:00")))); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(5); + + // NOT LT + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(5); + } + + @TestTemplate + public void testUnpartitionedHours() throws Exception { + createUnpartitionedTable(spark, tableName); + + SparkScanBuilder builder = scanBuilder(); + + HoursFunction.TimestampToHoursFunction function = new HoursFunction.TimestampToHoursFunction(); + UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts"))); + Predicate predicate = + new Predicate( + ">=", + expressions( + udf, intLit(timestampStrToHourOrdinal("2017-11-22T06:02:09.243857+00:00")))); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + + // NOT GTEQ + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + } + + @TestTemplate + public void testPartitionedHours() throws Exception { + createPartitionedTable(spark, tableName, "hours(ts)"); + + SparkScanBuilder builder = scanBuilder(); + + HoursFunction.TimestampToHoursFunction function = new HoursFunction.TimestampToHoursFunction(); + UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts"))); + Predicate predicate = + new Predicate( + ">=", + expressions( + udf, intLit(timestampStrToHourOrdinal("2017-11-22T06:02:09.243857+00:00")))); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(8); + + // NOT GTEQ + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(2); + } + + @TestTemplate + public void testUnpartitionedBucketLong() throws Exception { + createUnpartitionedTable(spark, tableName); + + SparkScanBuilder builder = scanBuilder(); + + BucketFunction.BucketLong function = new BucketFunction.BucketLong(DataTypes.LongType); + UserDefinedScalarFunc udf = toUDF(function, expressions(intLit(5), fieldRef("id"))); + Predicate predicate = new Predicate(">=", expressions(udf, intLit(2))); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + + // NOT GTEQ + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = 
builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + } + + @TestTemplate + public void testPartitionedBucketLong() throws Exception { + createPartitionedTable(spark, tableName, "bucket(5, id)"); + + SparkScanBuilder builder = scanBuilder(); + + BucketFunction.BucketLong function = new BucketFunction.BucketLong(DataTypes.LongType); + UserDefinedScalarFunc udf = toUDF(function, expressions(intLit(5), fieldRef("id"))); + Predicate predicate = new Predicate(">=", expressions(udf, intLit(2))); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(6); + + // NOT GTEQ + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(4); + } + + @TestTemplate + public void testUnpartitionedBucketString() throws Exception { + createUnpartitionedTable(spark, tableName); + + SparkScanBuilder builder = scanBuilder(); + + BucketFunction.BucketString function = new BucketFunction.BucketString(); + UserDefinedScalarFunc udf = toUDF(function, expressions(intLit(5), fieldRef("data"))); + Predicate predicate = new Predicate("<=", expressions(udf, intLit(2))); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + + // NOT LTEQ + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + } + + @TestTemplate + public void testPartitionedBucketString() throws Exception { + createPartitionedTable(spark, tableName, "bucket(5, data)"); + + SparkScanBuilder builder = scanBuilder(); + + BucketFunction.BucketString function = new BucketFunction.BucketString(); + UserDefinedScalarFunc udf = toUDF(function, expressions(intLit(5), fieldRef("data"))); + Predicate predicate = new Predicate("<=", expressions(udf, intLit(2))); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(6); + + // NOT LTEQ + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(4); + } + + @TestTemplate + public void testUnpartitionedTruncateString() throws Exception { + createUnpartitionedTable(spark, tableName); + + SparkScanBuilder builder = scanBuilder(); + + TruncateFunction.TruncateString function = new TruncateFunction.TruncateString(); + UserDefinedScalarFunc udf = toUDF(function, expressions(intLit(4), fieldRef("data"))); + Predicate predicate = new Predicate("<>", expressions(udf, stringLit("data"))); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + + // NOT NotEqual + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + } + + @TestTemplate + public void testPartitionedTruncateString() throws Exception { + createPartitionedTable(spark, tableName, "truncate(4, data)"); + + SparkScanBuilder builder = scanBuilder(); + + TruncateFunction.TruncateString function = new TruncateFunction.TruncateString(); + 
UserDefinedScalarFunc udf = toUDF(function, expressions(intLit(4), fieldRef("data"))); + Predicate predicate = new Predicate("<>", expressions(udf, stringLit("data"))); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(5); + + // NOT NotEqual + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(5); + } + + @TestTemplate + public void testUnpartitionedIsNull() throws Exception { + createUnpartitionedTable(spark, tableName); + + SparkScanBuilder builder = scanBuilder(); + + TruncateFunction.TruncateString function = new TruncateFunction.TruncateString(); + UserDefinedScalarFunc udf = toUDF(function, expressions(intLit(4), fieldRef("data"))); + Predicate predicate = new Predicate("IS_NULL", expressions(udf)); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + + // NOT IsNull + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + } + + @TestTemplate + public void testPartitionedIsNull() throws Exception { + createPartitionedTable(spark, tableName, "truncate(4, data)"); + + SparkScanBuilder builder = scanBuilder(); + + TruncateFunction.TruncateString function = new TruncateFunction.TruncateString(); + UserDefinedScalarFunc udf = toUDF(function, expressions(intLit(4), fieldRef("data"))); + Predicate predicate = new Predicate("IS_NULL", expressions(udf)); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(0); + + // NOT IsNULL + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + } + + @TestTemplate + public void testUnpartitionedIsNotNull() throws Exception { + createUnpartitionedTable(spark, tableName); + + SparkScanBuilder builder = scanBuilder(); + + TruncateFunction.TruncateString function = new TruncateFunction.TruncateString(); + UserDefinedScalarFunc udf = toUDF(function, expressions(intLit(4), fieldRef("data"))); + Predicate predicate = new Predicate("IS_NOT_NULL", expressions(udf)); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + + // NOT IsNotNull + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + } + + @TestTemplate + public void testPartitionedIsNotNull() throws Exception { + createPartitionedTable(spark, tableName, "truncate(4, data)"); + + SparkScanBuilder builder = scanBuilder(); + + TruncateFunction.TruncateString function = new TruncateFunction.TruncateString(); + UserDefinedScalarFunc udf = toUDF(function, expressions(intLit(4), fieldRef("data"))); + Predicate predicate = new Predicate("IS_NOT_NULL", expressions(udf)); + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + + // NOT IsNotNULL + builder = scanBuilder(); + + predicate = new Not(predicate); + 
pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(0); + } + + @TestTemplate + public void testUnpartitionedAnd() throws Exception { + createUnpartitionedTable(spark, tableName); + + SparkScanBuilder builder = scanBuilder(); + + YearsFunction.TimestampToYearsFunction tsToYears = new YearsFunction.TimestampToYearsFunction(); + UserDefinedScalarFunc udf1 = toUDF(tsToYears, expressions(fieldRef("ts"))); + Predicate predicate1 = new Predicate("=", expressions(udf1, intLit(2017 - 1970))); + + BucketFunction.BucketLong bucketLong = new BucketFunction.BucketLong(DataTypes.LongType); + UserDefinedScalarFunc udf = toUDF(bucketLong, expressions(intLit(5), fieldRef("id"))); + Predicate predicate2 = new Predicate(">=", expressions(udf, intLit(2))); + Predicate predicate = new And(predicate1, predicate2); + + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + + // NOT (years(ts) = 47 AND bucket(id, 5) >= 2) + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + } + + @TestTemplate + public void testPartitionedAnd() throws Exception { + createPartitionedTable(spark, tableName, "years(ts), bucket(5, id)"); + + SparkScanBuilder builder = scanBuilder(); + + YearsFunction.TimestampToYearsFunction tsToYears = new YearsFunction.TimestampToYearsFunction(); + UserDefinedScalarFunc udf1 = toUDF(tsToYears, expressions(fieldRef("ts"))); + Predicate predicate1 = new Predicate("=", expressions(udf1, intLit(2017 - 1970))); + + BucketFunction.BucketLong bucketLong = new BucketFunction.BucketLong(DataTypes.LongType); + UserDefinedScalarFunc udf = toUDF(bucketLong, expressions(intLit(5), fieldRef("id"))); + Predicate predicate2 = new Predicate(">=", expressions(udf, intLit(2))); + Predicate predicate = new And(predicate1, predicate2); + + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(1); + + // NOT (years(ts) = 47 AND bucket(id, 5) >= 2) + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(9); + } + + @TestTemplate + public void testUnpartitionedOr() throws Exception { + createUnpartitionedTable(spark, tableName); + + SparkScanBuilder builder = scanBuilder(); + + YearsFunction.TimestampToYearsFunction tsToYears = new YearsFunction.TimestampToYearsFunction(); + UserDefinedScalarFunc udf1 = toUDF(tsToYears, expressions(fieldRef("ts"))); + Predicate predicate1 = new Predicate("=", expressions(udf1, intLit(2017 - 1970))); + + BucketFunction.BucketLong bucketLong = new BucketFunction.BucketLong(DataTypes.LongType); + UserDefinedScalarFunc udf = toUDF(bucketLong, expressions(intLit(5), fieldRef("id"))); + Predicate predicate2 = new Predicate(">=", expressions(udf, intLit(2))); + Predicate predicate = new Or(predicate1, predicate2); + + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(10); + + // NOT (years(ts) = 47 OR bucket(id, 5) >= 2) + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + 
assertThat(scan.planInputPartitions().length).isEqualTo(10); + } + + @TestTemplate + public void testPartitionedOr() throws Exception { + createPartitionedTable(spark, tableName, "years(ts), bucket(5, id)"); + + SparkScanBuilder builder = scanBuilder(); + + YearsFunction.TimestampToYearsFunction tsToYears = new YearsFunction.TimestampToYearsFunction(); + UserDefinedScalarFunc udf1 = toUDF(tsToYears, expressions(fieldRef("ts"))); + Predicate predicate1 = new Predicate("=", expressions(udf1, intLit(2018 - 1970))); + + BucketFunction.BucketLong bucketLong = new BucketFunction.BucketLong(DataTypes.LongType); + UserDefinedScalarFunc udf = toUDF(bucketLong, expressions(intLit(5), fieldRef("id"))); + Predicate predicate2 = new Predicate(">=", expressions(udf, intLit(2))); + Predicate predicate = new Or(predicate1, predicate2); + + pushFilters(builder, predicate); + Batch scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(6); + + // NOT (years(ts) = 48 OR bucket(id, 5) >= 2) + builder = scanBuilder(); + + predicate = new Not(predicate); + pushFilters(builder, predicate); + scan = builder.build().toBatch(); + + assertThat(scan.planInputPartitions().length).isEqualTo(4); + } + + private SparkScanBuilder scanBuilder() throws Exception { + Table table = Spark3Util.loadIcebergTable(spark, tableName); + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", tableName)); + + return new SparkScanBuilder(spark, table, options); + } + + private void pushFilters(ScanBuilder scan, Predicate... predicates) { + assertThat(scan).isInstanceOf(SupportsPushDownV2Filters.class); + SupportsPushDownV2Filters filterable = (SupportsPushDownV2Filters) scan; + filterable.pushPredicates(predicates); + } + + private Expression[] expressions(Expression... 
expressions) { + return expressions; + } + + private void checkColStatisticsNotReported(SparkScan scan, long expectedRowCount) { + Statistics stats = scan.estimateStatistics(); + assertThat(stats.numRows().getAsLong()).isEqualTo(expectedRowCount); + + Map columnStats = stats.columnStats(); + assertThat(columnStats).isEmpty(); + } + + private void checkColStatisticsReported( + SparkScan scan, long expectedRowCount, Map expectedNDVs) { + Statistics stats = scan.estimateStatistics(); + assertThat(stats.numRows().getAsLong()).isEqualTo(expectedRowCount); + + Map columnStats = stats.columnStats(); + if (expectedNDVs.isEmpty()) { + assertThat(columnStats.values().stream().allMatch(value -> value.distinctCount().isEmpty())) + .isTrue(); + } else { + for (Map.Entry entry : expectedNDVs.entrySet()) { + assertThat( + columnStats.get(FieldReference.column(entry.getKey())).distinctCount().getAsLong()) + .isEqualTo(entry.getValue()); + } + } + } + + private static LiteralValue intLit(int value) { + return LiteralValue.apply(value, DataTypes.IntegerType); + } + + private static LiteralValue dateLit(int value) { + return LiteralValue.apply(value, DataTypes.DateType); + } + + private static LiteralValue stringLit(String value) { + return LiteralValue.apply(value, DataTypes.StringType); + } + + private static NamedReference fieldRef(String col) { + return FieldReference.apply(col); + } + + private static UserDefinedScalarFunc toUDF(BoundFunction function, Expression[] expressions) { + return new UserDefinedScalarFunc(function.name(), function.canonicalName(), expressions); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkStagedScan.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkStagedScan.java new file mode 100644 index 000000000000..6ce2ce623835 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkStagedScan.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.List; +import java.util.UUID; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkStagedScan extends CatalogTestBase { + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testTaskSetLoading() throws NoSuchTableException, IOException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should produce 1 snapshot").hasSize(1); + + try (CloseableIterable fileScanTasks = table.newScan().planFiles()) { + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + String setID = UUID.randomUUID().toString(); + taskSetManager.stageTasks(table, setID, ImmutableList.copyOf(fileScanTasks)); + + // load the staged file set + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, setID) + .load(tableName); + + // write the records back essentially duplicating data + scanDF.writeTo(tableName).append(); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "a"), row(1, "a"), row(2, "b"), row(2, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testTaskSetPlanning() throws NoSuchTableException, IOException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should produce 2 snapshot").hasSize(2); + + try (CloseableIterable fileScanTasks = table.newScan().planFiles()) { + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + String setID = UUID.randomUUID().toString(); + List tasks = ImmutableList.copyOf(fileScanTasks); + taskSetManager.stageTasks(table, setID, tasks); + + // load the staged file set and make sure each file is in a separate split + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, setID) + .option(SparkReadOptions.SPLIT_SIZE, tasks.get(0).file().fileSizeInBytes()) + .load(tableName); + assertThat(scanDF.javaRDD().getNumPartitions()) + .as("Num partitions should match") + .isEqualTo(2); + + // load the staged file set and make sure we combine both files into a single split + scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, setID) + 
.option(SparkReadOptions.SPLIT_SIZE, Long.MAX_VALUE) + .load(tableName); + assertThat(scanDF.javaRDD().getNumPartitions()) + .as("Num partitions should match") + .isEqualTo(1); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkTable.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkTable.java new file mode 100644 index 000000000000..46ee484b39ea --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkTable.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.CatalogManager; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkTable extends CatalogTestBase { + + @BeforeEach + public void createTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testTableEquality() throws NoSuchTableException { + CatalogManager catalogManager = spark.sessionState().catalogManager(); + TableCatalog catalog = (TableCatalog) catalogManager.catalog(catalogName); + Identifier identifier = Identifier.of(tableIdent.namespace().levels(), tableIdent.name()); + SparkTable table1 = (SparkTable) catalog.loadTable(identifier); + SparkTable table2 = (SparkTable) catalog.loadTable(identifier); + + // different instances pointing to the same table must be equivalent + assertThat(table1).as("References must be different").isNotSameAs(table2); + assertThat(table1).as("Tables must be equivalent").isEqualTo(table2); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkWriterMetrics.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkWriterMetrics.java new file mode 100644 index 000000000000..06ecc20c2fc3 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkWriterMetrics.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestWriterMetrics; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkWriterMetrics extends TestWriterMetrics { + + public TestSparkWriterMetrics(FileFormat fileFormat) { + super(fileFormat); + } + + @Override + protected FileWriterFactory newWriterFactory(Table sourceTable) { + return SparkFileWriterFactory.builderFor(sourceTable) + .dataSchema(sourceTable.schema()) + .dataFileFormat(fileFormat) + .deleteFileFormat(fileFormat) + .positionDeleteRowSchema(sourceTable.schema()) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data, boolean boolValue, Long longValue) { + InternalRow row = new GenericInternalRow(3); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + + InternalRow nested = new GenericInternalRow(2); + nested.update(0, boolValue); + nested.update(1, longValue); + + row.update(2, nested); + return row; + } + + @Override + protected InternalRow toGenericRow(int value, int repeated) { + InternalRow row = new GenericInternalRow(repeated); + for (int i = 0; i < repeated; i++) { + row.update(i, value); + } + return row; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestStreamingOffset.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestStreamingOffset.java new file mode 100644 index 000000000000..d55e718ff2d3 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestStreamingOffset.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.util.Arrays; +import org.apache.iceberg.util.JsonUtil; +import org.junit.jupiter.api.Test; + +public class TestStreamingOffset { + + @Test + public void testJsonConversion() { + StreamingOffset[] expected = + new StreamingOffset[] { + new StreamingOffset(System.currentTimeMillis(), 1L, false), + new StreamingOffset(System.currentTimeMillis(), 2L, false), + new StreamingOffset(System.currentTimeMillis(), 3L, false), + new StreamingOffset(System.currentTimeMillis(), 4L, true) + }; + assertThat(Arrays.stream(expected).map(elem -> StreamingOffset.fromJson(elem.json())).toArray()) + .as("StreamingOffsets should match") + .isEqualTo(expected); + } + + @Test + public void testToJson() throws Exception { + StreamingOffset expected = new StreamingOffset(System.currentTimeMillis(), 1L, false); + ObjectNode actual = JsonUtil.mapper().createObjectNode(); + actual.put("version", 1); + actual.put("snapshot_id", expected.snapshotId()); + actual.put("position", 1L); + actual.put("scan_all_files", false); + String expectedJson = expected.json(); + String actualJson = JsonUtil.mapper().writeValueAsString(actual); + assertThat(actualJson).isEqualTo(expectedJson); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreaming.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreaming.java new file mode 100644 index 000000000000..19da8c7d50c4 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreaming.java @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.execution.streaming.MemoryStream; +import org.apache.spark.sql.streaming.DataStreamWriter; +import org.apache.spark.sql.streaming.StreamingQuery; +import org.apache.spark.sql.streaming.StreamingQueryException; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import scala.Option; +import scala.collection.JavaConverters; + +public class TestStructuredStreaming { + + private static final Configuration CONF = new Configuration(); + private static final Schema SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + private static SparkSession spark = null; + + @TempDir private Path temp; + + @BeforeAll + public static void startSpark() { + TestStructuredStreaming.spark = + SparkSession.builder() + .master("local[2]") + .config("spark.sql.shuffle.partitions", 4) + .getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestStructuredStreaming.spark; + TestStructuredStreaming.spark = null; + currentSpark.stop(); + } + + @Test + public void testStreamingWriteAppendMode() throws Exception { + File parent = temp.resolve("parquet").toFile(); + File location = new File(parent, "test-table"); + File checkpoint = new File(parent, "checkpoint"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "1"), + new SimpleRecord(2, "2"), + new SimpleRecord(3, "3"), + new SimpleRecord(4, "4")); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + DataStreamWriter streamWriter = + inputStream + .toDF() + .selectExpr("value AS id", "CAST (value AS STRING) AS data") + .writeStream() + .outputMode("append") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()); + + try { + // start the original query with checkpointing + StreamingQuery query = streamWriter.start(); + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + query.processAllAvailable(); + List batch2 = Lists.newArrayList(3, 4); + send(batch2, inputStream); + query.processAllAvailable(); + query.stop(); + + // remove the last commit to force Spark to reprocess batch #1 + File lastCommitFile = new File(checkpoint + "/commits/1"); + deleteFileAndCrc(lastCommitFile); + + // restart the query from the checkpoint + 
StreamingQuery restartedQuery = streamWriter.start(); + restartedQuery.processAllAvailable(); + + // ensure the write was idempotent + Dataset result = spark.read().format("iceberg").load(location.toString()); + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + assertThat(table.snapshots()).as("Number of snapshots should match").hasSize(2); + } finally { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + } + + @Test + public void testStreamingWriteCompleteMode() throws Exception { + File parent = temp.resolve("parquet").toFile(); + File location = new File(parent, "test-table"); + File checkpoint = new File(parent, "checkpoint"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = + Lists.newArrayList( + new SimpleRecord(2, "1"), new SimpleRecord(3, "2"), new SimpleRecord(1, "3")); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + DataStreamWriter streamWriter = + inputStream + .toDF() + .groupBy("value") + .count() + .selectExpr("CAST(count AS INT) AS id", "CAST (value AS STRING) AS data") + .writeStream() + .outputMode("complete") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()); + + try { + // start the original query with checkpointing + StreamingQuery query = streamWriter.start(); + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + query.processAllAvailable(); + List batch2 = Lists.newArrayList(1, 2, 2, 3); + send(batch2, inputStream); + query.processAllAvailable(); + query.stop(); + + // remove the last commit to force Spark to reprocess batch #1 + File lastCommitFile = new File(checkpoint + "/commits/1"); + deleteFileAndCrc(lastCommitFile); + + // restart the query from the checkpoint + StreamingQuery restartedQuery = streamWriter.start(); + restartedQuery.processAllAvailable(); + + // ensure the write was idempotent + Dataset result = spark.read().format("iceberg").load(location.toString()); + List actual = + result.orderBy("data").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + assertThat(table.snapshots()).as("Number of snapshots should match").hasSize(2); + } finally { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + } + + @Test + public void testStreamingWriteCompleteModeWithProjection() throws Exception { + File parent = temp.resolve("parquet").toFile(); + File location = new File(parent, "test-table"); + File checkpoint = new File(parent, "checkpoint"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, null), new SimpleRecord(2, null), new SimpleRecord(3, null)); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + DataStreamWriter streamWriter = + inputStream + .toDF() + .groupBy("value") + .count() + .selectExpr("CAST(count AS INT) AS id") // select 
only id column + .writeStream() + .outputMode("complete") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()); + + try { + // start the original query with checkpointing + StreamingQuery query = streamWriter.start(); + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + query.processAllAvailable(); + List batch2 = Lists.newArrayList(1, 2, 2, 3); + send(batch2, inputStream); + query.processAllAvailable(); + query.stop(); + + // remove the last commit to force Spark to reprocess batch #1 + File lastCommitFile = new File(checkpoint + "/commits/1"); + deleteFileAndCrc(lastCommitFile); + + // restart the query from the checkpoint + StreamingQuery restartedQuery = streamWriter.start(); + restartedQuery.processAllAvailable(); + + // ensure the write was idempotent + Dataset result = spark.read().format("iceberg").load(location.toString()); + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + + assertThat(actual).as("Number of rows should match").hasSameSizeAs(expected); + assertThat(actual).as("Result rows should match").isEqualTo(expected); + assertThat(table.snapshots()).as("Number of snapshots should match").hasSize(2); + } finally { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + } + + @Test + public void testStreamingWriteUpdateMode() throws Exception { + File parent = temp.resolve("parquet").toFile(); + File location = new File(parent, "test-table"); + File checkpoint = new File(parent, "checkpoint"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + tables.create(SCHEMA, spec, location.toString()); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + DataStreamWriter streamWriter = + inputStream + .toDF() + .selectExpr("value AS id", "CAST (value AS STRING) AS data") + .writeStream() + .outputMode("update") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()); + + try { + StreamingQuery query = streamWriter.start(); + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + + assertThatThrownBy(query::processAllAvailable) + .isInstanceOf(StreamingQueryException.class) + .hasMessageContaining("does not support Update mode"); + } finally { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + } + + private MemoryStream newMemoryStream(int id, SQLContext sqlContext, Encoder encoder) { + return new MemoryStream<>(id, sqlContext, Option.empty(), encoder); + } + + private void send(List records, MemoryStream stream) { + stream.addData(JavaConverters.asScalaBuffer(records)); + } + + private void deleteFileAndCrc(File file) throws IOException { + File crcFile = new File(file.getParent(), "." 
+ file.getName() + ".crc"); + if (crcFile.exists()) { + assertThat(crcFile.delete()).as("CRC file must be deleted: " + crcFile.getPath()).isTrue(); + } + assertThat(file.delete()).as("Commit file must be deleted: " + file.getPath()).isTrue(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreamingRead3.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreamingRead3.java new file mode 100644 index 000000000000..86d65ba0e558 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreamingRead3.java @@ -0,0 +1,779 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.expressions.Expressions.ref; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.File; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.DataOperations; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Files; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.RewriteFiles; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.spark.api.java.function.VoidFunction2; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.streaming.DataStreamWriter; +import org.apache.spark.sql.streaming.OutputMode; +import org.apache.spark.sql.streaming.StreamingQuery; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import 
org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public final class TestStructuredStreamingRead3 extends CatalogTestBase { + + private Table table; + + private final AtomicInteger microBatches = new AtomicInteger(); + + /** + * test data to be used by multiple writes each write creates a snapshot and writes a list of + * records + */ + private static final List> TEST_DATA_MULTIPLE_SNAPSHOTS = + Lists.newArrayList( + Lists.newArrayList( + new SimpleRecord(1, "one"), new SimpleRecord(2, "two"), new SimpleRecord(3, "three")), + Lists.newArrayList(new SimpleRecord(4, "four"), new SimpleRecord(5, "five")), + Lists.newArrayList(new SimpleRecord(6, "six"), new SimpleRecord(7, "seven"))); + + /** + * test data - to be used for multiple write batches each batch inturn will have multiple + * snapshots + */ + private static final List>> TEST_DATA_MULTIPLE_WRITES_MULTIPLE_SNAPSHOTS = + Lists.newArrayList( + Lists.newArrayList( + Lists.newArrayList( + new SimpleRecord(1, "one"), + new SimpleRecord(2, "two"), + new SimpleRecord(3, "three")), + Lists.newArrayList(new SimpleRecord(4, "four"), new SimpleRecord(5, "five"))), + Lists.newArrayList( + Lists.newArrayList(new SimpleRecord(6, "six"), new SimpleRecord(7, "seven")), + Lists.newArrayList(new SimpleRecord(8, "eight"), new SimpleRecord(9, "nine"))), + Lists.newArrayList( + Lists.newArrayList( + new SimpleRecord(10, "ten"), + new SimpleRecord(11, "eleven"), + new SimpleRecord(12, "twelve")), + Lists.newArrayList( + new SimpleRecord(13, "thirteen"), new SimpleRecord(14, "fourteen")), + Lists.newArrayList( + new SimpleRecord(15, "fifteen"), new SimpleRecord(16, "sixteen")))); + + @BeforeAll + public static void setupSpark() { + // disable AQE as tests assume that writes generate a particular number of files + spark.conf().set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"); + } + + @BeforeEach + public void setupTable() { + sql( + "CREATE TABLE %s " + + "(id INT, data STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(3, id)) " + + "TBLPROPERTIES ('commit.manifest.min-count-to-merge'='3', 'commit.manifest-merge.enabled'='true')", + tableName); + this.table = validationCatalog.loadTable(tableIdent); + microBatches.set(0); + } + + @AfterEach + public void stopStreams() throws TimeoutException { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testReadStreamOnIcebergTableWithMultipleSnapshots() throws Exception { + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + StreamingQuery query = startStream(); + + List actual = rowsAvailable(query); + assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @TestTemplate + public void testReadStreamOnIcebergTableWithMultipleSnapshots_WithNumberOfFiles_1() + throws Exception { + appendDataAsMultipleSnapshots(TEST_DATA_MULTIPLE_SNAPSHOTS); + + assertThat( + microBatchCount( + ImmutableMap.of(SparkReadOptions.STREAMING_MAX_FILES_PER_MICRO_BATCH, "1"))) + .isEqualTo(6); + } + + @TestTemplate + public void testReadStreamOnIcebergTableWithMultipleSnapshots_WithNumberOfFiles_2() + throws Exception { + appendDataAsMultipleSnapshots(TEST_DATA_MULTIPLE_SNAPSHOTS); + + assertThat( + microBatchCount( + 
ImmutableMap.of(SparkReadOptions.STREAMING_MAX_FILES_PER_MICRO_BATCH, "2"))) + .isEqualTo(3); + } + + @TestTemplate + public void testReadStreamOnIcebergTableWithMultipleSnapshots_WithNumberOfRows_1() + throws Exception { + appendDataAsMultipleSnapshots(TEST_DATA_MULTIPLE_SNAPSHOTS); + + // only 1 micro-batch will be formed and we will read data partially + assertThat( + microBatchCount( + ImmutableMap.of(SparkReadOptions.STREAMING_MAX_ROWS_PER_MICRO_BATCH, "1"))) + .isEqualTo(1); + + StreamingQuery query = startStream(SparkReadOptions.STREAMING_MAX_ROWS_PER_MICRO_BATCH, "1"); + + // check answer correctness only 1 record read the micro-batch will be stuck + List actual = rowsAvailable(query); + assertThat(actual) + .containsExactlyInAnyOrderElementsOf( + Lists.newArrayList(TEST_DATA_MULTIPLE_SNAPSHOTS.get(0).get(0))); + } + + @TestTemplate + public void testReadStreamOnIcebergTableWithMultipleSnapshots_WithNumberOfRows_4() + throws Exception { + appendDataAsMultipleSnapshots(TEST_DATA_MULTIPLE_SNAPSHOTS); + + assertThat( + microBatchCount( + ImmutableMap.of(SparkReadOptions.STREAMING_MAX_ROWS_PER_MICRO_BATCH, "4"))) + .isEqualTo(2); + } + + @TestTemplate + public void testReadStreamOnIcebergThenAddData() throws Exception { + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + + StreamingQuery query = startStream(); + + appendDataAsMultipleSnapshots(expected); + + List actual = rowsAvailable(query); + assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @TestTemplate + public void testReadingStreamFromTimestamp() throws Exception { + List dataBeforeTimestamp = + Lists.newArrayList( + new SimpleRecord(-2, "minustwo"), + new SimpleRecord(-1, "minusone"), + new SimpleRecord(0, "zero")); + + appendData(dataBeforeTimestamp); + + table.refresh(); + long streamStartTimestamp = table.currentSnapshot().timestampMillis() + 1; + + StreamingQuery query = + startStream(SparkReadOptions.STREAM_FROM_TIMESTAMP, Long.toString(streamStartTimestamp)); + + List empty = rowsAvailable(query); + assertThat(empty).isEmpty(); + + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + List actual = rowsAvailable(query); + + assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @TestTemplate + public void testReadingStreamFromFutureTimetsamp() throws Exception { + long futureTimestamp = System.currentTimeMillis() + 10000; + + StreamingQuery query = + startStream(SparkReadOptions.STREAM_FROM_TIMESTAMP, Long.toString(futureTimestamp)); + + List actual = rowsAvailable(query); + assertThat(actual).isEmpty(); + + List data = + Lists.newArrayList( + new SimpleRecord(-2, "minustwo"), + new SimpleRecord(-1, "minusone"), + new SimpleRecord(0, "zero")); + + // Perform several inserts that should not show up because the fromTimestamp has not elapsed + IntStream.range(0, 3) + .forEach( + x -> { + appendData(data); + assertThat(rowsAvailable(query)).isEmpty(); + }); + + waitUntilAfter(futureTimestamp); + + // Data appended after the timestamp should appear + appendData(data); + actual = rowsAvailable(query); + assertThat(actual).containsExactlyInAnyOrderElementsOf(data); + } + + @TestTemplate + public void testReadingStreamFromTimestampFutureWithExistingSnapshots() throws Exception { + List dataBeforeTimestamp = + Lists.newArrayList( + new SimpleRecord(1, "one"), new SimpleRecord(2, "two"), new SimpleRecord(3, "three")); + appendData(dataBeforeTimestamp); + + long streamStartTimestamp = System.currentTimeMillis() + 
2000; + + // Start the stream with a future timestamp after the current snapshot + StreamingQuery query = + startStream(SparkReadOptions.STREAM_FROM_TIMESTAMP, Long.toString(streamStartTimestamp)); + List actual = rowsAvailable(query); + assertThat(actual).isEmpty(); + + // Stream should contain data added after the timestamp elapses + waitUntilAfter(streamStartTimestamp); + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + assertThat(rowsAvailable(query)) + .containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @TestTemplate + public void testReadingStreamFromTimestampOfExistingSnapshot() throws Exception { + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + + // Create an existing snapshot with some data + appendData(expected.get(0)); + table.refresh(); + long firstSnapshotTime = table.currentSnapshot().timestampMillis(); + + // Start stream giving the first Snapshot's time as the start point + StreamingQuery stream = + startStream(SparkReadOptions.STREAM_FROM_TIMESTAMP, Long.toString(firstSnapshotTime)); + + // Append rest of expected data + for (int i = 1; i < expected.size(); i++) { + appendData(expected.get(i)); + } + + List actual = rowsAvailable(stream); + assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @TestTemplate + public void testReadingStreamWithExpiredSnapshotFromTimestamp() throws TimeoutException { + List firstSnapshotRecordList = Lists.newArrayList(new SimpleRecord(1, "one")); + + List secondSnapshotRecordList = Lists.newArrayList(new SimpleRecord(2, "two")); + + List thirdSnapshotRecordList = Lists.newArrayList(new SimpleRecord(3, "three")); + + List expectedRecordList = Lists.newArrayList(); + expectedRecordList.addAll(secondSnapshotRecordList); + expectedRecordList.addAll(thirdSnapshotRecordList); + + appendData(firstSnapshotRecordList); + table.refresh(); + long firstSnapshotid = table.currentSnapshot().snapshotId(); + long firstSnapshotCommitTime = table.currentSnapshot().timestampMillis(); + + appendData(secondSnapshotRecordList); + appendData(thirdSnapshotRecordList); + + table.expireSnapshots().expireSnapshotId(firstSnapshotid).commit(); + + StreamingQuery query = + startStream( + SparkReadOptions.STREAM_FROM_TIMESTAMP, String.valueOf(firstSnapshotCommitTime)); + List actual = rowsAvailable(query); + assertThat(actual).containsExactlyInAnyOrderElementsOf(expectedRecordList); + } + + @TestTemplate + public void testResumingStreamReadFromCheckpoint() throws Exception { + File writerCheckpointFolder = temp.resolve("writer-checkpoint-folder").toFile(); + File writerCheckpoint = new File(writerCheckpointFolder, "writer-checkpoint"); + File output = temp.resolve("junit").toFile(); + + DataStreamWriter querySource = + spark + .readStream() + .format("iceberg") + .load(tableName) + .writeStream() + .option("checkpointLocation", writerCheckpoint.toString()) + .format("parquet") + .queryName("checkpoint_test") + .option("path", output.getPath()); + + StreamingQuery startQuery = querySource.start(); + startQuery.processAllAvailable(); + startQuery.stop(); + + List expected = Lists.newArrayList(); + for (List> expectedCheckpoint : + TEST_DATA_MULTIPLE_WRITES_MULTIPLE_SNAPSHOTS) { + // New data was added while the stream was down + appendDataAsMultipleSnapshots(expectedCheckpoint); + expected.addAll(Lists.newArrayList(Iterables.concat(Iterables.concat(expectedCheckpoint)))); + + // Stream starts up again from checkpoint read the newly added data and shut down + 
StreamingQuery restartedQuery = querySource.start(); + restartedQuery.processAllAvailable(); + restartedQuery.stop(); + + // Read data added by the stream + List actual = + spark.read().load(output.getPath()).as(Encoders.bean(SimpleRecord.class)).collectAsList(); + assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + } + + @TestTemplate + public void testFailReadingCheckpointInvalidSnapshot() throws IOException, TimeoutException { + File writerCheckpointFolder = temp.resolve("writer-checkpoint-folder").toFile(); + File writerCheckpoint = new File(writerCheckpointFolder, "writer-checkpoint"); + File output = temp.resolve("junit").toFile(); + + DataStreamWriter querySource = + spark + .readStream() + .format("iceberg") + .load(tableName) + .writeStream() + .option("checkpointLocation", writerCheckpoint.toString()) + .format("parquet") + .queryName("checkpoint_test") + .option("path", output.getPath()); + + List firstSnapshotRecordList = Lists.newArrayList(new SimpleRecord(1, "one")); + List secondSnapshotRecordList = Lists.newArrayList(new SimpleRecord(2, "two")); + StreamingQuery startQuery = querySource.start(); + + appendData(firstSnapshotRecordList); + table.refresh(); + long firstSnapshotid = table.currentSnapshot().snapshotId(); + startQuery.processAllAvailable(); + startQuery.stop(); + + appendData(secondSnapshotRecordList); + + table.expireSnapshots().expireSnapshotId(firstSnapshotid).commit(); + + StreamingQuery restartedQuery = querySource.start(); + assertThatThrownBy(restartedQuery::processAllAvailable) + .hasCauseInstanceOf(IllegalStateException.class) + .hasMessageContaining( + String.format( + "Cannot load current offset at snapshot %d, the snapshot was expired or removed", + firstSnapshotid)); + } + + @TestTemplate + public void testParquetOrcAvroDataInOneTable() throws Exception { + List parquetFileRecords = + Lists.newArrayList( + new SimpleRecord(1, "one"), new SimpleRecord(2, "two"), new SimpleRecord(3, "three")); + + List orcFileRecords = + Lists.newArrayList(new SimpleRecord(4, "four"), new SimpleRecord(5, "five")); + + List avroFileRecords = + Lists.newArrayList(new SimpleRecord(6, "six"), new SimpleRecord(7, "seven")); + + appendData(parquetFileRecords); + appendData(orcFileRecords, "orc"); + appendData(avroFileRecords, "avro"); + + StreamingQuery query = startStream(); + assertThat(rowsAvailable(query)) + .containsExactlyInAnyOrderElementsOf( + Iterables.concat(parquetFileRecords, orcFileRecords, avroFileRecords)); + } + + @TestTemplate + public void testReadStreamFromEmptyTable() throws Exception { + StreamingQuery stream = startStream(); + List actual = rowsAvailable(stream); + assertThat(actual).isEmpty(); + } + + @TestTemplate + public void testReadStreamWithSnapshotTypeOverwriteErrorsOut() throws Exception { + // upgrade table to version 2 - to facilitate creation of Snapshot of type OVERWRITE. 
+ TableOperations ops = ((BaseTable) table).operations(); + TableMetadata meta = ops.current(); + ops.commit(meta, meta.upgradeToFormatVersion(2)); + + // fill table with some initial data + List> dataAcrossSnapshots = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(dataAcrossSnapshots); + + Schema deleteRowSchema = table.schema().select("data"); + Record dataDelete = GenericRecord.create(deleteRowSchema); + List dataDeletes = + Lists.newArrayList( + dataDelete.copy("data", "one") // id = 1 + ); + + DeleteFile eqDeletes = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(0), + dataDeletes, + deleteRowSchema); + + DataFile dataFile = + DataFiles.builder(table.spec()) + .withPath(File.createTempFile("junit", null, temp.toFile()).getPath()) + .withFileSizeInBytes(10) + .withRecordCount(1) + .withFormat(FileFormat.PARQUET) + .build(); + + table.newRowDelta().addRows(dataFile).addDeletes(eqDeletes).commit(); + + // check pre-condition - that the above Delete file write - actually resulted in snapshot of + // type OVERWRITE + assertThat(table.currentSnapshot().operation()).isEqualTo(DataOperations.OVERWRITE); + + StreamingQuery query = startStream(); + + assertThatThrownBy(query::processAllAvailable) + .cause() + .isInstanceOf(IllegalStateException.class) + .hasMessageStartingWith("Cannot process overwrite snapshot"); + } + + @TestTemplate + public void testReadStreamWithSnapshotTypeRewriteDataFilesIgnoresReplace() throws Exception { + // fill table with some data + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + makeRewriteDataFiles(); + + assertThat( + microBatchCount( + ImmutableMap.of(SparkReadOptions.STREAMING_MAX_FILES_PER_MICRO_BATCH, "1"))) + .isEqualTo(6); + } + + @TestTemplate + public void testReadStreamWithSnapshotTypeRewriteDataFilesIgnoresReplaceMaxRows() + throws Exception { + // fill table with some data + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + makeRewriteDataFiles(); + + assertThat( + microBatchCount( + ImmutableMap.of(SparkReadOptions.STREAMING_MAX_ROWS_PER_MICRO_BATCH, "4"))) + .isEqualTo(2); + } + + @TestTemplate + public void testReadStreamWithSnapshotTypeRewriteDataFilesIgnoresReplaceMaxFilesAndRows() + throws Exception { + // fill table with some data + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + makeRewriteDataFiles(); + + assertThat( + microBatchCount( + ImmutableMap.of( + SparkReadOptions.STREAMING_MAX_ROWS_PER_MICRO_BATCH, + "4", + SparkReadOptions.STREAMING_MAX_FILES_PER_MICRO_BATCH, + "1"))) + .isEqualTo(6); + } + + @TestTemplate + public void testReadStreamWithSnapshotType2RewriteDataFilesIgnoresReplace() throws Exception { + // fill table with some data + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + makeRewriteDataFiles(); + makeRewriteDataFiles(); + + assertThat( + microBatchCount( + ImmutableMap.of(SparkReadOptions.STREAMING_MAX_FILES_PER_MICRO_BATCH, "1"))) + .isEqualTo(6); + } + + @TestTemplate + public void testReadStreamWithSnapshotTypeRewriteDataFilesIgnoresReplaceFollowedByAppend() + throws Exception { + // fill table with some data + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + makeRewriteDataFiles(); + + appendDataAsMultipleSnapshots(expected); + + assertThat( + microBatchCount( + 
ImmutableMap.of(SparkReadOptions.STREAMING_MAX_FILES_PER_MICRO_BATCH, "1"))) + .isEqualTo(12); + } + + @TestTemplate + public void testReadStreamWithSnapshotTypeReplaceIgnoresReplace() throws Exception { + // fill table with some data + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + // this should create a snapshot with type Replace. + table.rewriteManifests().clusterBy(f -> 1).commit(); + + // check pre-condition + assertThat(table.currentSnapshot().operation()).isEqualTo(DataOperations.REPLACE); + + StreamingQuery query = startStream(); + List actual = rowsAvailable(query); + assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @TestTemplate + public void testReadStreamWithSnapshotTypeDeleteErrorsOut() throws Exception { + table.updateSpec().removeField("id_bucket").addField(ref("id")).commit(); + + // fill table with some data + List> dataAcrossSnapshots = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(dataAcrossSnapshots); + + // this should create a snapshot with type delete. + table.newDelete().deleteFromRowFilter(Expressions.equal("id", 4)).commit(); + + // check pre-condition - that the above delete operation on table resulted in Snapshot of Type + // DELETE. + assertThat(table.currentSnapshot().operation()).isEqualTo(DataOperations.DELETE); + + StreamingQuery query = startStream(); + + assertThatThrownBy(query::processAllAvailable) + .cause() + .isInstanceOf(IllegalStateException.class) + .hasMessageStartingWith("Cannot process delete snapshot"); + } + + @TestTemplate + public void testReadStreamWithSnapshotTypeDeleteAndSkipDeleteOption() throws Exception { + table.updateSpec().removeField("id_bucket").addField(ref("id")).commit(); + + // fill table with some data + List> dataAcrossSnapshots = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(dataAcrossSnapshots); + + // this should create a snapshot with type delete. + table.newDelete().deleteFromRowFilter(Expressions.equal("id", 4)).commit(); + + // check pre-condition - that the above delete operation on table resulted in Snapshot of Type + // DELETE. + assertThat(table.currentSnapshot().operation()).isEqualTo(DataOperations.DELETE); + + StreamingQuery query = startStream(SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS, "true"); + assertThat(rowsAvailable(query)) + .containsExactlyInAnyOrderElementsOf(Iterables.concat(dataAcrossSnapshots)); + } + + @TestTemplate + public void testReadStreamWithSnapshotTypeDeleteAndSkipOverwriteOption() throws Exception { + table.updateSpec().removeField("id_bucket").addField(ref("id")).commit(); + + // fill table with some data + List> dataAcrossSnapshots = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(dataAcrossSnapshots); + + DataFile dataFile = + DataFiles.builder(table.spec()) + .withPath(File.createTempFile("junit", null, temp.toFile()).getPath()) + .withFileSizeInBytes(10) + .withRecordCount(1) + .withFormat(FileFormat.PARQUET) + .build(); + + // this should create a snapshot with type overwrite. + table + .newOverwrite() + .addFile(dataFile) + .overwriteByRowFilter(Expressions.greaterThan("id", 4)) + .commit(); + + // check pre-condition - that the above delete operation on table resulted in Snapshot of Type + // OVERWRITE. 
+    assertThat(table.currentSnapshot().operation()).isEqualTo(DataOperations.OVERWRITE);
+
+    StreamingQuery query = startStream(SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS, "true");
+    assertThat(rowsAvailable(query))
+        .containsExactlyInAnyOrderElementsOf(Iterables.concat(dataAcrossSnapshots));
+  }
+
+  /**
+   * We are testing that all the files in a rewrite snapshot are skipped. Creates a rewrite data
+   * files snapshot using existing files.
+   */
+  public void makeRewriteDataFiles() {
+    table.refresh();
+
+    // we are testing that all the files in a rewrite snapshot are skipped
+    // create a rewrite data files snapshot using existing files
+    RewriteFiles rewrite = table.newRewrite();
+    Iterable<Snapshot> it = table.snapshots();
+    for (Snapshot snapshot : it) {
+      if (snapshot.operation().equals(DataOperations.APPEND)) {
+        Iterable<DataFile> datafiles = snapshot.addedDataFiles(table.io());
+        for (DataFile datafile : datafiles) {
+          rewrite.addFile(datafile);
+          rewrite.deleteFile(datafile);
+        }
+      }
+    }
+    rewrite.commit();
+  }
+
+  /**
+   * Appends each list as a snapshot on the Iceberg table at the given location. Accepts a list of
+   * lists, each list representing the data for one snapshot.
+   */
+  private void appendDataAsMultipleSnapshots(List<List<SimpleRecord>> data) {
+    for (List<SimpleRecord> l : data) {
+      appendData(l);
+    }
+  }
+
+  private void appendData(List<SimpleRecord> data) {
+    appendData(data, "parquet");
+  }
+
+  private void appendData(List<SimpleRecord> data, String format) {
+    Dataset<Row> df = spark.createDataFrame(data, SimpleRecord.class);
+    df.select("id", "data")
+        .write()
+        .format("iceberg")
+        .option("write-format", format)
+        .mode("append")
+        .save(tableName);
+  }
+
+  private static final String MEMORY_TABLE = "_stream_view_mem";
+
+  private StreamingQuery startStream(Map<String, String> options) throws TimeoutException {
+    return spark
+        .readStream()
+        .options(options)
+        .format("iceberg")
+        .load(tableName)
+        .writeStream()
+        .options(options)
+        .format("memory")
+        .queryName(MEMORY_TABLE)
+        .outputMode(OutputMode.Append())
+        .start();
+  }
+
+  private StreamingQuery startStream() throws TimeoutException {
+    return startStream(Collections.emptyMap());
+  }
+
+  private StreamingQuery startStream(String key, String value) throws TimeoutException {
+    return startStream(
+        ImmutableMap.of(key, value, SparkReadOptions.STREAMING_MAX_FILES_PER_MICRO_BATCH, "1"));
+  }
+
+  private int microBatchCount(Map<String, String> options) throws TimeoutException {
+    Dataset<Row> ds = spark.readStream().options(options).format("iceberg").load(tableName);
+
+    ds.writeStream()
+        .options(options)
+        .foreachBatch(
+            (VoidFunction2<Dataset<Row>, Long>)
+                (dataset, batchId) -> {
+                  microBatches.getAndIncrement();
+                })
+        .start()
+        .processAllAvailable();
+
+    stopStreams();
+    return microBatches.get();
+  }
+
+  private List<SimpleRecord> rowsAvailable(StreamingQuery query) {
+    query.processAllAvailable();
+    return spark
+        .sql("select * from " + MEMORY_TABLE)
+        .as(Encoders.bean(SimpleRecord.class))
+        .collectAsList();
+  }
+}
diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestTables.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestTables.java
new file mode 100644
index 000000000000..b54bb315c543
--- /dev/null
+++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestTables.java
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.util.Map; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.Files; +import org.apache.iceberg.LocationProviders; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.CommitFailedException; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.LocationProvider; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; + +// TODO: Use the copy of this from core. +class TestTables { + private TestTables() {} + + static TestTable create(File temp, String name, Schema schema, PartitionSpec spec) { + return create(temp, name, schema, spec, ImmutableMap.of()); + } + + static TestTable create( + File temp, String name, Schema schema, PartitionSpec spec, Map properties) { + TestTableOperations ops = new TestTableOperations(name); + if (ops.current() != null) { + throw new AlreadyExistsException("Table %s already exists at location: %s", name, temp); + } + ops.commit(null, TableMetadata.newTableMetadata(schema, spec, temp.toString(), properties)); + return new TestTable(ops, name); + } + + static TestTable load(String name) { + TestTableOperations ops = new TestTableOperations(name); + if (ops.current() == null) { + return null; + } + return new TestTable(ops, name); + } + + static boolean drop(String name) { + synchronized (METADATA) { + return METADATA.remove(name) != null; + } + } + + static class TestTable extends BaseTable { + private final TestTableOperations ops; + + private TestTable(TestTableOperations ops, String name) { + super(ops, name); + this.ops = ops; + } + + @Override + public TestTableOperations operations() { + return ops; + } + } + + private static final Map METADATA = Maps.newHashMap(); + + static void clearTables() { + synchronized (METADATA) { + METADATA.clear(); + } + } + + static TableMetadata readMetadata(String tableName) { + synchronized (METADATA) { + return METADATA.get(tableName); + } + } + + static void replaceMetadata(String tableName, TableMetadata metadata) { + synchronized (METADATA) { + METADATA.put(tableName, metadata); + } + } + + static class TestTableOperations implements TableOperations { + + private final String tableName; + private TableMetadata current = null; + private long lastSnapshotId = 0; + private int failCommits = 0; + + TestTableOperations(String tableName) { + this.tableName = tableName; + refresh(); + if (current 
!= null) { + for (Snapshot snap : current.snapshots()) { + this.lastSnapshotId = Math.max(lastSnapshotId, snap.snapshotId()); + } + } else { + this.lastSnapshotId = 0; + } + } + + void failCommits(int numFailures) { + this.failCommits = numFailures; + } + + @Override + public TableMetadata current() { + return current; + } + + @Override + public TableMetadata refresh() { + synchronized (METADATA) { + this.current = METADATA.get(tableName); + } + return current; + } + + @Override + public void commit(TableMetadata base, TableMetadata metadata) { + if (base != current) { + throw new CommitFailedException("Cannot commit changes based on stale metadata"); + } + synchronized (METADATA) { + refresh(); + if (base == current) { + if (failCommits > 0) { + this.failCommits -= 1; + throw new CommitFailedException("Injected failure"); + } + METADATA.put(tableName, metadata); + this.current = metadata; + } else { + throw new CommitFailedException( + "Commit failed: table was updated at %d", base.lastUpdatedMillis()); + } + } + } + + @Override + public FileIO io() { + return new LocalFileIO(); + } + + @Override + public LocationProvider locationProvider() { + Preconditions.checkNotNull( + current, "Current metadata should not be null when locationProvider is called"); + return LocationProviders.locationsFor(current.location(), current.properties()); + } + + @Override + public String metadataFileLocation(String fileName) { + return new File(new File(current.location(), "metadata"), fileName).getAbsolutePath(); + } + + @Override + public long newSnapshotId() { + long nextSnapshotId = lastSnapshotId + 1; + this.lastSnapshotId = nextSnapshotId; + return nextSnapshotId; + } + } + + static class LocalFileIO implements FileIO { + + @Override + public InputFile newInputFile(String path) { + return Files.localInput(path); + } + + @Override + public OutputFile newOutputFile(String path) { + return Files.localOutput(new File(path)); + } + + @Override + public void deleteFile(String path) { + if (!new File(path).delete()) { + throw new RuntimeIOException("Failed to delete file: " + path); + } + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestTimestampWithoutZone.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestTimestampWithoutZone.java new file mode 100644 index 000000000000..306444b9f29f --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestTimestampWithoutZone.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.Files.localOutput; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.time.LocalDateTime; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.TestBase; +import org.apache.iceberg.spark.data.GenericsHelpers; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestTimestampWithoutZone extends TestBase { + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "ts", Types.TimestampType.withoutZone()), + Types.NestedField.optional(3, "data", Types.StringType.get())); + + private static SparkSession spark = null; + + @BeforeAll + public static void startSpark() { + TestTimestampWithoutZone.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestTimestampWithoutZone.spark; + TestTimestampWithoutZone.spark = null; + currentSpark.stop(); + } + + @TempDir private Path temp; + + @Parameter(index = 0) + private String format; + + @Parameter(index = 1) + private boolean vectorized; + + @Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + {"parquet", false}, + {"parquet", true}, + {"avro", false} + }; + } + + private File parent = null; + private File unpartitioned = null; + private List records = null; + + @BeforeEach + public void writeUnpartitionedTable() throws IOException { + this.parent = temp.resolve("TestTimestampWithoutZone").toFile(); + this.unpartitioned = new File(parent, "unpartitioned"); + File dataFolder = new File(unpartitioned, "data"); + assertThat(dataFolder.mkdirs()).as("Mkdir should succeed").isTrue(); + + Table table = TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), unpartitioned.toString()); + Schema tableSchema = table.schema(); // use the table schema because ids are reassigned + + FileFormat fileFormat = 
FileFormat.fromString(format); + + File testFile = new File(dataFolder, fileFormat.addExtension(UUID.randomUUID().toString())); + + // create records using the table's schema + this.records = testRecords(tableSchema); + + try (FileAppender writer = + new GenericAppenderFactory(tableSchema).newAppender(localOutput(testFile), fileFormat)) { + writer.addAll(records); + } + + DataFile file = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(records.size()) + .withFileSizeInBytes(testFile.length()) + .withPath(testFile.toString()) + .build(); + + table.newAppend().appendFile(file).commit(); + } + + @TestTemplate + public void testUnpartitionedTimestampWithoutZone() { + assertEqualsSafe(SCHEMA.asStruct(), records, read(unpartitioned.toString(), vectorized)); + } + + @TestTemplate + public void testUnpartitionedTimestampWithoutZoneProjection() { + Schema projection = SCHEMA.select("id", "ts"); + assertEqualsSafe( + projection.asStruct(), + records.stream().map(r -> projectFlat(projection, r)).collect(Collectors.toList()), + read(unpartitioned.toString(), vectorized, "id", "ts")); + } + + @TestTemplate + public void testUnpartitionedTimestampWithoutZoneAppend() { + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()) + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(unpartitioned.toString()); + + assertEqualsSafe( + SCHEMA.asStruct(), + Stream.concat(records.stream(), records.stream()).collect(Collectors.toList()), + read(unpartitioned.toString(), vectorized)); + } + + private static Record projectFlat(Schema projection, Record record) { + Record result = GenericRecord.create(projection); + List fields = projection.asStruct().fields(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + result.set(i, record.getField(field.name())); + } + return result; + } + + public static void assertEqualsSafe( + Types.StructType struct, List expected, List actual) { + assertThat(actual).as("Number of results should match expected").hasSameSizeAs(expected); + for (int i = 0; i < expected.size(); i += 1) { + GenericsHelpers.assertEqualsSafe(struct, expected.get(i), actual.get(i)); + } + } + + private List testRecords(Schema schema) { + return Lists.newArrayList( + record(schema, 0L, parseToLocal("2017-12-22T09:20:44.294658"), "junction"), + record(schema, 1L, parseToLocal("2017-12-22T07:15:34.582910"), "alligator"), + record(schema, 2L, parseToLocal("2017-12-22T06:02:09.243857"), "forrest"), + record(schema, 3L, parseToLocal("2017-12-22T03:10:11.134509"), "clapping"), + record(schema, 4L, parseToLocal("2017-12-22T00:34:00.184671"), "brush"), + record(schema, 5L, parseToLocal("2017-12-21T22:20:08.935889"), "trap"), + record(schema, 6L, parseToLocal("2017-12-21T21:55:30.589712"), "element"), + record(schema, 7L, parseToLocal("2017-12-21T17:31:14.532797"), "limited"), + record(schema, 8L, parseToLocal("2017-12-21T15:21:51.237521"), "global"), + record(schema, 9L, parseToLocal("2017-12-21T15:02:15.230570"), "goldfish")); + } + + private static List read(String table, boolean vectorized) { + return read(table, vectorized, "*"); + } + + private static List read( + String table, boolean vectorized, String select0, String... 
selectN) { + Dataset dataset = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(table) + .select(select0, selectN); + return dataset.collectAsList(); + } + + private static LocalDateTime parseToLocal(String timestamp) { + return LocalDateTime.parse(timestamp); + } + + private static Record record(Schema schema, Object... values) { + Record rec = GenericRecord.create(schema); + for (int i = 0; i < values.length; i += 1) { + rec.set(i, values[i]); + } + return rec; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestWriteMetricsConfig.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestWriteMetricsConfig.java new file mode 100644 index 000000000000..841268a6be0e --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/TestWriteMetricsConfig.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.spark.SparkSchemaUtil.convert; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestWriteMetricsConfig { + + private static final Configuration CONF = new Configuration(); + private static 
final Schema SIMPLE_SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + private static final Schema COMPLEX_SCHEMA = + new Schema( + required(1, "longCol", Types.IntegerType.get()), + optional(2, "strCol", Types.StringType.get()), + required( + 3, + "record", + Types.StructType.of( + required(4, "id", Types.IntegerType.get()), + required(5, "data", Types.StringType.get())))); + + @TempDir private Path temp; + + private static SparkSession spark = null; + private static JavaSparkContext sc = null; + + @BeforeAll + public static void startSpark() { + TestWriteMetricsConfig.spark = SparkSession.builder().master("local[2]").getOrCreate(); + TestWriteMetricsConfig.sc = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = TestWriteMetricsConfig.spark; + TestWriteMetricsConfig.spark = null; + TestWriteMetricsConfig.sc = null; + currentSpark.stop(); + } + + @Test + public void testFullMetricsCollectionForParquet() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "full"); + Table table = tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + + assertThat(file.nullValueCounts()).hasSize(2); + assertThat(file.valueCounts()).hasSize(2); + assertThat(file.lowerBounds()).hasSize(2); + assertThat(file.upperBounds()).hasSize(2); + } + } + + @Test + public void testCountMetricsCollectionForParquet() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "counts"); + Table table = tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + assertThat(file.nullValueCounts()).hasSize(2); + assertThat(file.valueCounts()).hasSize(2); + assertThat(file.lowerBounds()).isEmpty(); + assertThat(file.upperBounds()).isEmpty(); + } + } + + @Test + public void testNoMetricsCollectionForParquet() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map 
properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "none"); + Table table = tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + assertThat(file.nullValueCounts()).isEmpty(); + assertThat(file.valueCounts()).isEmpty(); + assertThat(file.lowerBounds()).isEmpty(); + assertThat(file.upperBounds()).isEmpty(); + } + } + + @Test + public void testCustomMetricCollectionForParquet() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "counts"); + properties.put("write.metadata.metrics.column.id", "full"); + Table table = tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + Schema schema = table.schema(); + Types.NestedField id = schema.findField("id"); + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + assertThat(file.nullValueCounts()).hasSize(2); + assertThat(file.valueCounts()).hasSize(2); + assertThat(file.lowerBounds()).hasSize(1).containsKey(id.fieldId()); + assertThat(file.upperBounds()).hasSize(1).containsKey(id.fieldId()); + } + } + + @Test + public void testBadCustomMetricCollectionForParquet() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "counts"); + properties.put("write.metadata.metrics.column.ids", "full"); + + assertThatThrownBy(() -> tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation)) + .isInstanceOf(ValidationException.class) + .hasMessageStartingWith( + "Invalid metrics config, could not find column ids from table prop write.metadata.metrics.column.ids in schema table"); + } + + @Test + public void testCustomMetricCollectionForNestedParquet() throws IOException { + String tableLocation = temp.resolve("iceberg-table").toFile().toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(COMPLEX_SCHEMA).identity("strCol").build(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "none"); + properties.put("write.metadata.metrics.column.longCol", "counts"); + properties.put("write.metadata.metrics.column.record.id", "full"); + properties.put("write.metadata.metrics.column.record.data", "truncate(2)"); 
+ Table table = tables.create(COMPLEX_SCHEMA, spec, properties, tableLocation); + + Iterable rows = RandomData.generateSpark(COMPLEX_SCHEMA, 10, 0); + JavaRDD rdd = sc.parallelize(Lists.newArrayList(rows)); + Dataset df = + spark.internalCreateDataFrame(JavaRDD.toRDD(rdd), convert(COMPLEX_SCHEMA), false); + + df.coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + Schema schema = table.schema(); + Types.NestedField longCol = schema.findField("longCol"); + Types.NestedField recordId = schema.findField("record.id"); + Types.NestedField recordData = schema.findField("record.data"); + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + + Map nullValueCounts = file.nullValueCounts(); + assertThat(nullValueCounts) + .hasSize(3) + .containsKey(longCol.fieldId()) + .containsKey(recordId.fieldId()) + .containsKey(recordData.fieldId()); + + Map valueCounts = file.valueCounts(); + assertThat(valueCounts) + .hasSize(3) + .containsKey(longCol.fieldId()) + .containsKey(recordId.fieldId()) + .containsKey(recordData.fieldId()); + + Map lowerBounds = file.lowerBounds(); + assertThat(lowerBounds).hasSize(2).containsKey(recordId.fieldId()); + + ByteBuffer recordDataLowerBound = lowerBounds.get(recordData.fieldId()); + assertThat(ByteBuffers.toByteArray(recordDataLowerBound)).hasSize(2); + + Map upperBounds = file.upperBounds(); + assertThat(upperBounds).hasSize(2).containsKey(recordId.fieldId()); + + ByteBuffer recordDataUpperBound = upperBounds.get(recordData.fieldId()); + assertThat(ByteBuffers.toByteArray(recordDataUpperBound)).hasSize(2); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/ThreeColumnRecord.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/ThreeColumnRecord.java new file mode 100644 index 000000000000..554557df416c --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/source/ThreeColumnRecord.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Objects; + +public class ThreeColumnRecord { + private Integer c1; + private String c2; + private String c3; + + public ThreeColumnRecord() {} + + public ThreeColumnRecord(Integer c1, String c2, String c3) { + this.c1 = c1; + this.c2 = c2; + this.c3 = c3; + } + + public Integer getC1() { + return c1; + } + + public void setC1(Integer c1) { + this.c1 = c1; + } + + public String getC2() { + return c2; + } + + public void setC2(String c2) { + this.c2 = c2; + } + + public String getC3() { + return c3; + } + + public void setC3(String c3) { + this.c3 = c3; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ThreeColumnRecord that = (ThreeColumnRecord) o; + return Objects.equals(c1, that.c1) + && Objects.equals(c2, that.c2) + && Objects.equals(c3, that.c3); + } + + @Override + public int hashCode() { + return Objects.hash(c1, c2, c3); + } + + @Override + public String toString() { + return "ThreeColumnRecord{" + "c1=" + c1 + ", c2='" + c2 + '\'' + ", c3='" + c3 + '\'' + '}'; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/PartitionedWritesTestBase.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/PartitionedWritesTestBase.java new file mode 100644 index 000000000000..97f8e6142dc5 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/PartitionedWritesTestBase.java @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Arrays; +import java.util.List; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.functions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public abstract class PartitionedWritesTestBase extends CatalogTestBase { + + @BeforeEach + public void createTables() { + sql( + "CREATE TABLE %s (id bigint, data string) USING iceberg PARTITIONED BY (truncate(id, 3))", + tableName); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testInsertAppend() { + assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget())) + .as("Should have 5 rows after insert") + .isEqualTo(3L); + + sql("INSERT INTO %s VALUES (4, 'd'), (5, 'e')", commitTarget()); + + assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget())) + .as("Should have 5 rows after insert") + .isEqualTo(5L); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testInsertOverwrite() { + assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget())) + .as("Should have 5 rows after insert") + .isEqualTo(3L); + + // 4 and 5 replace 3 in the partition (id - (id % 3)) = 3 + sql("INSERT OVERWRITE %s VALUES (4, 'd'), (5, 'e')", commitTarget()); + + assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget())) + .as("Should have 4 rows after overwrite") + .isEqualTo(4L); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testDataFrameV2Append() throws NoSuchTableException { + assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget())) + .as("Should have 3 rows") + .isEqualTo(3L); + + List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(commitTarget()).append(); + + assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget())) + .as("Should have 5 rows after insert") + .isEqualTo(5L); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testDataFrameV2DynamicOverwrite() throws NoSuchTableException { + assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget())) + .as("Should have 3 rows") + .isEqualTo(3L); + + List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + 
ds.writeTo(commitTarget()).overwritePartitions(); + + assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget())) + .as("Should have 4 rows after overwrite") + .isEqualTo(4L); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testDataFrameV2Overwrite() throws NoSuchTableException { + assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget())) + .as("Should have 3 rows") + .isEqualTo(3L); + + List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(commitTarget()).overwrite(functions.col("id").$less(3)); + + assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget())) + .as("Should have 3 rows after overwrite") + .isEqualTo(3L); + + List expected = ImmutableList.of(row(3L, "c"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testViewsReturnRecentResults() { + assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget())) + .as("Should have 3 rows") + .isEqualTo(3L); + + Dataset query = spark.sql("SELECT * FROM " + commitTarget() + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + assertEquals( + "View should have expected rows", ImmutableList.of(row(1L, "a")), sql("SELECT * FROM tmp")); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", commitTarget()); + + assertEquals( + "View should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM tmp")); + } + + // Asserts whether the given table .partitions table has the expected rows. Note that the output + // row should have spec_id and it is sorted by spec_id and selectPartitionColumns. + protected void assertPartitionMetadata( + String tableName, List expected, String... selectPartitionColumns) { + String[] fullyQualifiedCols = + Arrays.stream(selectPartitionColumns).map(s -> "partition." + s).toArray(String[]::new); + Dataset actualPartitionRows = + spark + .read() + .format("iceberg") + .load(tableName + ".partitions") + .select("spec_id", fullyQualifiedCols) + .orderBy("spec_id", fullyQualifiedCols); + + assertEquals( + "There are 3 partitions, one with the original spec ID and two with the new one", + expected, + rowsToJava(actualPartitionRows.collectAsList())); + } + + @TestTemplate + public void testWriteWithOutputSpec() throws NoSuchTableException { + Table table = validationCatalog.loadTable(tableIdent); + + // Drop all records in table to have a fresh start. + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + + final int originalSpecId = table.spec().specId(); + table.updateSpec().addField("data").commit(); + + // Refresh this when using SparkCatalog since otherwise the new spec would not be caught. + sql("REFRESH TABLE %s", tableName); + + // By default, we write to the current spec. + List data = ImmutableList.of(new SimpleRecord(10, "a")); + spark.createDataFrame(data, SimpleRecord.class).toDF().writeTo(tableName).append(); + + List expected = ImmutableList.of(row(10L, "a", table.spec().specId())); + assertEquals( + "Rows must match", + expected, + sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName)); + + // Output spec ID should be respected when present. 
+ data = ImmutableList.of(new SimpleRecord(11, "b"), new SimpleRecord(12, "c")); + spark + .createDataFrame(data, SimpleRecord.class) + .toDF() + .writeTo(tableName) + .option("output-spec-id", Integer.toString(originalSpecId)) + .append(); + + expected = + ImmutableList.of( + row(10L, "a", table.spec().specId()), + row(11L, "b", originalSpecId), + row(12L, "c", originalSpecId)); + assertEquals( + "Rows must match", + expected, + sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName)); + + // Verify that the actual partitions are written with the correct spec ID. + // Two of the partitions should have the original spec ID and one should have the new one. + // TODO: WAP branch does not support reading partitions table, skip this check for now. + expected = + ImmutableList.of( + row(originalSpecId, 9L, null), + row(originalSpecId, 12L, null), + row(table.spec().specId(), 9L, "a")); + assertPartitionMetadata(tableName, expected, "id_trunc", "data"); + + // Even the default spec ID should be followed when present. + data = ImmutableList.of(new SimpleRecord(13, "d")); + spark + .createDataFrame(data, SimpleRecord.class) + .toDF() + .writeTo(tableName) + .option("output-spec-id", Integer.toString(table.spec().specId())) + .append(); + + expected = + ImmutableList.of( + row(10L, "a", table.spec().specId()), + row(11L, "b", originalSpecId), + row(12L, "c", originalSpecId), + row(13L, "d", table.spec().specId())); + assertEquals( + "Rows must match", + expected, + sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName)); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java new file mode 100644 index 000000000000..6e09252704a1 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java @@ -0,0 +1,863 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.hive.HiveCatalog; +import org.apache.iceberg.hive.TestHiveMetastore; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.TestBase; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.execution.ExplainMode; +import org.apache.spark.sql.functions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; + +public class TestAggregatePushDown extends CatalogTestBase { + + @BeforeAll + public static void startMetastoreAndSpark() { + TestBase.metastore = new TestHiveMetastore(); + metastore.start(); + TestBase.hiveConf = metastore.hiveConf(); + + TestBase.spark.close(); + + TestBase.spark = + SparkSession.builder() + .master("local[2]") + .config("spark.sql.iceberg.aggregate_pushdown", "true") + .enableHiveSupport() + .getOrCreate(); + + TestBase.catalog = + (HiveCatalog) + CatalogUtil.loadCatalog( + HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf); + + try { + catalog.createNamespace(Namespace.of("default")); + } catch (AlreadyExistsException ignored) { + // the default namespace already exists. 
ignore the create error + } + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testDifferentDataTypesAggregatePushDownInPartitionedTable() { + testDifferentDataTypesAggregatePushDown(true); + } + + @TestTemplate + public void testDifferentDataTypesAggregatePushDownInNonPartitionedTable() { + testDifferentDataTypesAggregatePushDown(false); + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + private void testDifferentDataTypesAggregatePushDown(boolean hasPartitionCol) { + String createTable; + if (hasPartitionCol) { + createTable = + "CREATE TABLE %s (id LONG, int_data INT, boolean_data BOOLEAN, float_data FLOAT, double_data DOUBLE, " + + "decimal_data DECIMAL(14, 2), binary_data binary) USING iceberg PARTITIONED BY (id)"; + } else { + createTable = + "CREATE TABLE %s (id LONG, int_data INT, boolean_data BOOLEAN, float_data FLOAT, double_data DOUBLE, " + + "decimal_data DECIMAL(14, 2), binary_data binary) USING iceberg"; + } + + sql(createTable, tableName); + sql( + "INSERT INTO TABLE %s VALUES " + + "(1, null, false, null, null, 11.11, X'1111')," + + " (1, null, true, 2.222, 2.222222, 22.22, X'2222')," + + " (2, 33, false, 3.333, 3.333333, 33.33, X'3333')," + + " (2, 44, true, null, 4.444444, 44.44, X'4444')," + + " (3, 55, false, 5.555, 5.555555, 55.55, X'5555')," + + " (3, null, true, null, 6.666666, 66.66, null) ", + tableName); + + String select = + "SELECT count(*), max(id), min(id), count(id), " + + "max(int_data), min(int_data), count(int_data), " + + "max(boolean_data), min(boolean_data), count(boolean_data), " + + "max(float_data), min(float_data), count(float_data), " + + "max(double_data), min(double_data), count(double_data), " + + "max(decimal_data), min(decimal_data), count(decimal_data), " + + "max(binary_data), min(binary_data), count(binary_data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("count(*)") + && explainString.contains("max(id)") + && explainString.contains("min(id)") + && explainString.contains("count(id)") + && explainString.contains("max(int_data)") + && explainString.contains("min(int_data)") + && explainString.contains("count(int_data)") + && explainString.contains("max(boolean_data)") + && explainString.contains("min(boolean_data)") + && explainString.contains("count(boolean_data)") + && explainString.contains("max(float_data)") + && explainString.contains("min(float_data)") + && explainString.contains("count(float_data)") + && explainString.contains("max(double_data)") + && explainString.contains("min(double_data)") + && explainString.contains("count(double_data)") + && explainString.contains("max(decimal_data)") + && explainString.contains("min(decimal_data)") + && explainString.contains("count(decimal_data)") + && explainString.contains("max(binary_data)") + && explainString.contains("min(binary_data)") + && explainString.contains("count(binary_data)")) { + explainContainsPushDownAggregates = true; + } + + assertThat(explainContainsPushDownAggregates) + .as("explain should contain the pushed down aggregates") + .isTrue(); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add( + new Object[] { + 6L, + 3L, + 1L, + 6L, + 55, + 33, + 3L, + true, + false, + 6L, + 5.555f, + 2.222f, + 3L, + 6.666666, + 2.222222, + 5L, + new BigDecimal("66.66"), + new 
BigDecimal("11.11"), + 6L, + new byte[] {85, 85}, + new byte[] {17, 17}, + 5L + }); + assertEquals("min/max/count push down", expected, actual); + } + + @TestTemplate + public void testDateAndTimestampWithPartition() { + sql( + "CREATE TABLE %s (id bigint, data string, d date, ts timestamp) USING iceberg PARTITIONED BY (id)", + tableName); + sql( + "INSERT INTO %s VALUES (1, '1', date('2021-11-10'), null)," + + "(1, '2', date('2021-11-11'), timestamp('2021-11-11 22:22:22')), " + + "(2, '3', date('2021-11-12'), timestamp('2021-11-12 22:22:22')), " + + "(2, '4', date('2021-11-13'), timestamp('2021-11-13 22:22:22')), " + + "(3, '5', null, timestamp('2021-11-14 22:22:22')), " + + "(3, '6', date('2021-11-14'), null)", + tableName); + String select = "SELECT max(d), min(d), count(d), max(ts), min(ts), count(ts) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(d)") + && explainString.contains("min(d)") + && explainString.contains("count(d)") + && explainString.contains("max(ts)") + && explainString.contains("min(ts)") + && explainString.contains("count(ts)")) { + explainContainsPushDownAggregates = true; + } + + assertThat(explainContainsPushDownAggregates) + .as("explain should contain the pushed down aggregates") + .isTrue(); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add( + new Object[] { + Date.valueOf("2021-11-14"), + Date.valueOf("2021-11-10"), + 5L, + Timestamp.valueOf("2021-11-14 22:22:22.0"), + Timestamp.valueOf("2021-11-11 22:22:22.0"), + 4L + }); + assertEquals("min/max/count push down", expected, actual); + } + + @TestTemplate + public void testAggregateNotPushDownIfOneCantPushDown() { + sql("CREATE TABLE %s (id LONG, data DOUBLE) USING iceberg", tableName); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ", + tableName); + String select = "SELECT COUNT(data), SUM(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + assertThat(explainContainsPushDownAggregates) + .as("explain should not contain the pushed down aggregates") + .isFalse(); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6L, 23331.0}); + assertEquals("expected and actual should equal", expected, actual); + } + + @TestTemplate + public void testAggregatePushDownWithMetricsMode() { + sql("CREATE TABLE %s (id LONG, data DOUBLE) USING iceberg", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, TableProperties.DEFAULT_WRITE_METRICS_MODE, "none"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "id", "counts"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "data", "none"); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666)", + tableName); + + String select1 = "SELECT COUNT(data) FROM %s"; + + List explain1 = sql("EXPLAIN " + select1, tableName); + String explainString1 = 
explain1.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString1.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + // count(data) is not pushed down because the metrics mode is `none` + assertThat(explainContainsPushDownAggregates) + .as("explain should not contain the pushed down aggregates") + .isFalse(); + + List actual1 = sql(select1, tableName); + List expected1 = Lists.newArrayList(); + expected1.add(new Object[] {6L}); + assertEquals("expected and actual should equal", expected1, actual1); + + String select2 = "SELECT COUNT(id) FROM %s"; + List explain2 = sql("EXPLAIN " + select2, tableName); + String explainString2 = explain2.get(0)[0].toString().toLowerCase(Locale.ROOT); + if (explainString2.contains("count(id)")) { + explainContainsPushDownAggregates = true; + } + + // count(id) is pushed down because the metrics mode is `counts` + assertThat(explainContainsPushDownAggregates) + .as("explain should contain the pushed down aggregates") + .isTrue(); + + List actual2 = sql(select2, tableName); + List expected2 = Lists.newArrayList(); + expected2.add(new Object[] {6L}); + assertEquals("expected and actual should equal", expected2, actual2); + + String select3 = "SELECT COUNT(id), MAX(id) FROM %s"; + explainContainsPushDownAggregates = false; + List explain3 = sql("EXPLAIN " + select3, tableName); + String explainString3 = explain3.get(0)[0].toString().toLowerCase(Locale.ROOT); + if (explainString3.contains("count(id)")) { + explainContainsPushDownAggregates = true; + } + + // COUNT(id), MAX(id) are not pushed down because MAX(id) is not pushed down (metrics mode is + // `counts`) + assertThat(explainContainsPushDownAggregates) + .as("explain should not contain the pushed down aggregates") + .isFalse(); + + List actual3 = sql(select3, tableName); + List expected3 = Lists.newArrayList(); + expected3.add(new Object[] {6L, 3L}); + assertEquals("expected and actual should equal", expected3, actual3); + } + + @TestTemplate + public void testAggregateNotPushDownForStringType() { + sql("CREATE TABLE %s (id LONG, data STRING) USING iceberg", tableName); + sql( + "INSERT INTO TABLE %s VALUES (1, '1111'), (1, '2222'), (2, '3333'), (2, '4444'), (3, '5555'), (3, '6666') ", + tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, TableProperties.DEFAULT_WRITE_METRICS_MODE, "truncate(16)"); + + String select1 = "SELECT MAX(id), MAX(data) FROM %s"; + + List explain1 = sql("EXPLAIN " + select1, tableName); + String explainString1 = explain1.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString1.contains("max(id)")) { + explainContainsPushDownAggregates = true; + } + + assertThat(explainContainsPushDownAggregates) + .as("explain should not contain the pushed down aggregates") + .isFalse(); + + List actual1 = sql(select1, tableName); + List expected1 = Lists.newArrayList(); + expected1.add(new Object[] {3L, "6666"}); + assertEquals("expected and actual should equal", expected1, actual1); + + String select2 = "SELECT COUNT(data) FROM %s"; + List explain2 = sql("EXPLAIN " + select2, tableName); + String explainString2 = explain2.get(0)[0].toString().toLowerCase(Locale.ROOT); + if (explainString2.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + assertThat(explainContainsPushDownAggregates) + .as("explain should contain the pushed down aggregates") + .isTrue(); + + List actual2 = sql(select2, 
tableName);
+    List<Object[]> expected2 = Lists.newArrayList();
+    expected2.add(new Object[] {6L});
+    assertEquals("expected and actual should equal", expected2, actual2);
+
+    explainContainsPushDownAggregates = false;
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')",
+        tableName, TableProperties.DEFAULT_WRITE_METRICS_MODE, "full");
+    String select3 = "SELECT count(data), max(data) FROM %s";
+    List<Object[]> explain3 = sql("EXPLAIN " + select3, tableName);
+    String explainString3 = explain3.get(0)[0].toString().toLowerCase(Locale.ROOT);
+    if (explainString3.contains("count(data)") && explainString3.contains("max(data)")) {
+      explainContainsPushDownAggregates = true;
+    }
+
+    assertThat(explainContainsPushDownAggregates)
+        .as("explain should contain the pushed down aggregates")
+        .isTrue();
+
+    List<Object[]> actual3 = sql(select3, tableName);
+    List<Object[]> expected3 = Lists.newArrayList();
+    expected3.add(new Object[] {6L, "6666"});
+    assertEquals("expected and actual should equal", expected3, actual3);
+  }
+
+  @TestTemplate
+  public void testAggregatePushDownWithDataFilter() {
+    testAggregatePushDownWithFilter(false);
+  }
+
+  @TestTemplate
+  public void testAggregatePushDownWithPartitionFilter() {
+    testAggregatePushDownWithFilter(true);
+  }
+
+  private void testAggregatePushDownWithFilter(boolean partitionFilterOnly) {
+    String createTable;
+    if (!partitionFilterOnly) {
+      createTable = "CREATE TABLE %s (id LONG, data INT) USING iceberg";
+    } else {
+      createTable = "CREATE TABLE %s (id LONG, data INT) USING iceberg PARTITIONED BY (id)";
+    }
+
+    sql(createTable, tableName);
+
+    sql(
+        "INSERT INTO TABLE %s VALUES"
+            + " (1, 11),"
+            + " (1, 22),"
+            + " (2, 33),"
+            + " (2, 44),"
+            + " (3, 55),"
+            + " (3, 66) ",
+        tableName);
+
+    String select = "SELECT MIN(data) FROM %s WHERE id > 1";
+
+    List<Object[]> explain = sql("EXPLAIN " + select, tableName);
+    String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
+    boolean explainContainsPushDownAggregates = false;
+    if (explainString.contains("min(data)")) {
+      explainContainsPushDownAggregates = true;
+    }
+
+    if (!partitionFilterOnly) {
+      // Filters are not completely pushed down, we can't push down aggregates
+      assertThat(explainContainsPushDownAggregates)
+          .as("explain should not contain the pushed down aggregates")
+          .isFalse();
+    } else {
+      // Filters are completely pushed down, we can push down aggregates
+      assertThat(explainContainsPushDownAggregates)
+          .as("explain should contain the pushed down aggregates")
+          .isTrue();
+    }
+
+    List<Object[]> actual = sql(select, tableName);
+    List<Object[]> expected = Lists.newArrayList();
+    expected.add(new Object[] {33});
+    assertEquals("expected and actual should equal", expected, actual);
+  }
+
+  @TestTemplate
+  public void testAggregateWithComplexType() {
+    sql("CREATE TABLE %s (id INT, complex STRUCT<c1 INT, c2 STRING>) USING iceberg", tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\")),"
+            + "(2, named_struct(\"c1\", 2, \"c2\", \"v2\")), (3, null)",
+        tableName);
+    String select1 = "SELECT count(complex), count(id) FROM %s";
+    List<Object[]> explain = sql("EXPLAIN " + select1, tableName);
+    String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
+    boolean explainContainsPushDownAggregates = false;
+    if (explainString.contains("count(complex)")) {
+      explainContainsPushDownAggregates = true;
+    }
+
+    assertThat(explainContainsPushDownAggregates)
+        .as("count not pushed down for complex types")
+        .isFalse();
+
+    List<Object[]> actual = sql(select1, tableName);
+    List<Object[]> expected = Lists.newArrayList();
+    expected.add(new Object[] {2L, 3L});
+    assertEquals("count not push down", actual, expected);
+
+    String select2 = "SELECT max(complex) FROM %s";
+    explain = sql("EXPLAIN " + select2, tableName);
+    explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
+    explainContainsPushDownAggregates = false;
+    if (explainString.contains("max(complex)")) {
+      explainContainsPushDownAggregates = true;
+    }
+
+    assertThat(explainContainsPushDownAggregates)
+        .as("max not pushed down for complex types")
+        .isFalse();
+  }
+
+  @TestTemplate
+  public void testAggregationPushdownStructInteger() {
+    sql("CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1 INT>) USING iceberg", tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
+    sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2))", tableName);
+    sql("INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", 3))", tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_int.c1";
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 3L, 2L);
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_int.c1)",
+        "max(struct_with_int.c1)",
+        "min(struct_with_int.c1)");
+  }
+
+  @TestTemplate
+  public void testAggregationPushdownNestedStruct() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1 STRUCT<c2 STRUCT<c3 STRUCT<c4 INT>>>>) USING iceberg",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", NULL)))))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 2)))))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 3)))))",
+        tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_int.c1.c2.c3.c4";
+
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 3L, 2L);
+
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_int.c1.c2.c3.c4)",
+        "max(struct_with_int.c1.c2.c3.c4)",
+        "min(struct_with_int.c1.c2.c3.c4)");
+  }
+
+  @TestTemplate
+  public void testAggregationPushdownStructTimestamp() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_ts STRUCT<c1 TIMESTAMP>) USING iceberg",
+        tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", timestamp('2023-01-30T22:22:22Z')))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", timestamp('2023-01-30T22:23:23Z')))",
+        tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_ts.c1";
+
+    assertAggregates(
+        sql(query, aggField, aggField, aggField, tableName),
+        2L,
+        new Timestamp(1675117403000L),
+        new Timestamp(1675117342000L));
+
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_ts.c1)",
+        "max(struct_with_ts.c1)",
+        "min(struct_with_ts.c1)");
+  }
+
+  @TestTemplate
+  public void testAggregationPushdownOnBucketedColumn() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1 INT>) USING iceberg PARTITIONED BY (bucket(8, id))",
+        tableName);
+
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
+    sql("INSERT
INTO TABLE %s VALUES (null, named_struct(\"c1\", 2))", tableName); + sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 3))", tableName); + + String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s"; + String aggField = "id"; + assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 2L, 1L); + assertExplainContains( + sql("EXPLAIN " + query, aggField, aggField, aggField, tableName), + "count(id)", + "max(id)", + "min(id)"); + } + + private void assertAggregates( + List actual, Object expectedCount, Object expectedMax, Object expectedMin) { + Object actualCount = actual.get(0)[0]; + Object actualMax = actual.get(0)[1]; + Object actualMin = actual.get(0)[2]; + + assertThat(actualCount).as("Expected and actual count should equal").isEqualTo(expectedCount); + assertThat(actualMax).as("Expected and actual max should equal").isEqualTo(expectedMax); + assertThat(actualMin).as("Expected and actual min should equal").isEqualTo(expectedMin); + } + + private void assertExplainContains(List explain, String... expectedFragments) { + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + Arrays.stream(expectedFragments) + .forEach( + fragment -> + assertThat(explainString) + .as("Expected to find plan fragment in explain plan") + .contains(fragment)); + } + + @TestTemplate + public void testAggregatePushDownInDeleteCopyOnWrite() { + sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ", + tableName); + sql("DELETE FROM %s WHERE data = 1111", tableName); + String select = "SELECT max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)") + && explainString.contains("min(data)") + && explainString.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + assertThat(explainContainsPushDownAggregates) + .as("min/max/count pushed down for deleted") + .isTrue(); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6666, 2222, 5L}); + assertEquals("min/max/count push down", expected, actual); + } + + @TestTemplate + public void testAggregatePushDownForTimeTravel() { + sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ", + tableName); + + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + List expected1 = sql("SELECT count(id) FROM %s", tableName); + + sql("INSERT INTO %s VALUES (4, 7777), (5, 8888)", tableName); + List expected2 = sql("SELECT count(id) FROM %s", tableName); + + List explain1 = + sql("EXPLAIN SELECT count(id) FROM %s VERSION AS OF %s", tableName, snapshotId); + String explainString1 = explain1.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates1 = false; + if (explainString1.contains("count(id)")) { + explainContainsPushDownAggregates1 = true; + } + assertThat(explainContainsPushDownAggregates1).as("count pushed down").isTrue(); + + List actual1 = + sql("SELECT count(id) FROM %s VERSION AS OF %s", tableName, snapshotId); + assertEquals("count push down", expected1, actual1); + + List explain2 = sql("EXPLAIN SELECT count(id) FROM %s", 
tableName); + String explainString2 = explain2.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates2 = false; + if (explainString2.contains("count(id)")) { + explainContainsPushDownAggregates2 = true; + } + + assertThat(explainContainsPushDownAggregates2).as("count pushed down").isTrue(); + + List actual2 = sql("SELECT count(id) FROM %s", tableName); + assertEquals("count push down", expected2, actual2); + } + + @TestTemplate + public void testAllNull() { + sql("CREATE TABLE %s (id int, data int) USING iceberg PARTITIONED BY (id)", tableName); + sql( + "INSERT INTO %s VALUES (1, null)," + + "(1, null), " + + "(2, null), " + + "(2, null), " + + "(3, null), " + + "(3, null)", + tableName); + String select = "SELECT count(*), max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)") + && explainString.contains("min(data)") + && explainString.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + assertThat(explainContainsPushDownAggregates) + .as("explain should contain the pushed down aggregates") + .isTrue(); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6L, null, null, 0L}); + assertEquals("min/max/count push down", expected, actual); + } + + @TestTemplate + public void testAllNaN() { + sql("CREATE TABLE %s (id int, data float) USING iceberg PARTITIONED BY (id)", tableName); + sql( + "INSERT INTO %s VALUES (1, float('nan'))," + + "(1, float('nan')), " + + "(2, float('nan')), " + + "(2, float('nan')), " + + "(3, float('nan')), " + + "(3, float('nan'))", + tableName); + String select = "SELECT count(*), max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)") + || explainString.contains("min(data)") + || explainString.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + assertThat(explainContainsPushDownAggregates) + .as("explain should not contain the pushed down aggregates") + .isFalse(); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6L, Float.NaN, Float.NaN, 6L}); + assertEquals("expected and actual should equal", expected, actual); + } + + @TestTemplate + public void testNaN() { + sql("CREATE TABLE %s (id int, data float) USING iceberg PARTITIONED BY (id)", tableName); + sql( + "INSERT INTO %s VALUES (1, float('nan'))," + + "(1, float('nan')), " + + "(2, 2), " + + "(2, float('nan')), " + + "(3, float('nan')), " + + "(3, 1)", + tableName); + String select = "SELECT count(*), max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)") + || explainString.contains("min(data)") + || explainString.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + assertThat(explainContainsPushDownAggregates) + .as("explain should not contain the pushed down aggregates") + .isFalse(); + + List actual = sql(select, tableName); + 
List expected = Lists.newArrayList(); + expected.add(new Object[] {6L, Float.NaN, 1.0F, 6L}); + assertEquals("expected and actual should equal", expected, actual); + } + + @TestTemplate + public void testInfinity() { + sql( + "CREATE TABLE %s (id int, data1 float, data2 double, data3 double) USING iceberg PARTITIONED BY (id)", + tableName); + sql( + "INSERT INTO %s VALUES (1, float('-infinity'), double('infinity'), 1.23), " + + "(1, float('-infinity'), double('infinity'), -1.23), " + + "(1, float('-infinity'), double('infinity'), double('infinity')), " + + "(1, float('-infinity'), double('infinity'), 2.23), " + + "(1, float('-infinity'), double('infinity'), double('-infinity')), " + + "(1, float('-infinity'), double('infinity'), -2.23)", + tableName); + String select = + "SELECT count(*), max(data1), min(data1), count(data1), max(data2), min(data2), count(data2), max(data3), min(data3), count(data3) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data1)") + && explainString.contains("min(data1)") + && explainString.contains("count(data1)") + && explainString.contains("max(data2)") + && explainString.contains("min(data2)") + && explainString.contains("count(data2)") + && explainString.contains("max(data3)") + && explainString.contains("min(data3)") + && explainString.contains("count(data3)")) { + explainContainsPushDownAggregates = true; + } + + assertThat(explainContainsPushDownAggregates) + .as("explain should contain the pushed down aggregates") + .isTrue(); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add( + new Object[] { + 6L, + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + 6L, + Double.POSITIVE_INFINITY, + Double.POSITIVE_INFINITY, + 6L, + Double.POSITIVE_INFINITY, + Double.NEGATIVE_INFINITY, + 6L + }); + assertEquals("min/max/count push down", expected, actual); + } + + @TestTemplate + public void testAggregatePushDownForIncrementalScan() { + sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ", + tableName); + long snapshotId1 = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + sql("INSERT INTO %s VALUES (4, 7777), (5, 8888)", tableName); + long snapshotId2 = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + sql("INSERT INTO %s VALUES (6, -7777), (7, 8888)", tableName); + long snapshotId3 = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + sql("INSERT INTO %s VALUES (8, 7777), (9, 9999)", tableName); + + Dataset pushdownDs = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.START_SNAPSHOT_ID, snapshotId2) + .option(SparkReadOptions.END_SNAPSHOT_ID, snapshotId3) + .load(tableName) + .agg(functions.min("data"), functions.max("data"), functions.count("data")); + String explain1 = pushdownDs.queryExecution().explainString(ExplainMode.fromString("simple")); + assertThat(explain1).contains("LocalTableScan", "min(data)", "max(data)", "count(data)"); + + List expected1 = Lists.newArrayList(); + expected1.add(new Object[] {-7777, 8888, 2L}); + assertEquals("min/max/count push down", expected1, rowsToJava(pushdownDs.collectAsList())); + + Dataset unboundedPushdownDs = + spark + .read() + .format("iceberg") + 
.option(SparkReadOptions.START_SNAPSHOT_ID, snapshotId1)
+            .load(tableName)
+            .agg(functions.min("data"), functions.max("data"), functions.count("data"));
+    String explain2 =
+        unboundedPushdownDs.queryExecution().explainString(ExplainMode.fromString("simple"));
+    assertThat(explain2).contains("LocalTableScan", "min(data)", "max(data)", "count(data)");
+
+    List<Object[]> expected2 = Lists.newArrayList();
+    expected2.add(new Object[] {-7777, 9999, 6L});
+    assertEquals(
+        "min/max/count push down", expected2, rowsToJava(unboundedPushdownDs.collectAsList()));
+  }
+}
diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestAlterTable.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestAlterTable.java
new file mode 100644
index 000000000000..7c98888f1667
--- /dev/null
+++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestAlterTable.java
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.sql;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.Assumptions.assumeThat;
+
+import org.apache.iceberg.catalog.Namespace;
+import org.apache.iceberg.catalog.TableIdentifier;
+import org.apache.iceberg.hadoop.HadoopCatalog;
+import org.apache.iceberg.spark.CatalogTestBase;
+import org.apache.iceberg.types.Types;
+import org.apache.iceberg.types.Types.NestedField;
+import org.apache.spark.SparkException;
+import org.apache.spark.sql.AnalysisException;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+
+public class TestAlterTable extends CatalogTestBase {
+  private final TableIdentifier renamedIdent =
+      TableIdentifier.of(Namespace.of("default"), "table2");
+
+  @BeforeEach
+  public void createTable() {
+    sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+  }
+
+  @AfterEach
+  public void removeTable() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+    sql("DROP TABLE IF EXISTS %s2", tableName);
+  }
+
+  @TestTemplate
+  public void testAddColumnNotNull() {
+    assertThatThrownBy(() -> sql("ALTER TABLE %s ADD COLUMN c3 INT NOT NULL", tableName))
+        .isInstanceOf(SparkException.class)
+        .hasMessage(
+            "Unsupported table change: Incompatible change: cannot add required column: c3");
+  }
+
+  @TestTemplate
+  public void testAddColumn() {
+    sql(
+        "ALTER TABLE %s ADD COLUMN point struct<x: double NOT NULL, y: double NOT NULL> AFTER id",
+        tableName);
+
+    Types.StructType expectedSchema =
+        Types.StructType.of(
+            NestedField.required(1, "id", Types.LongType.get()),
+            NestedField.optional(
+                3,
+                "point",
+                Types.StructType.of(
+                    NestedField.required(4, "x", Types.DoubleType.get()),
+                    NestedField.required(5, "y", Types.DoubleType.get()))),
+            NestedField.optional(2, "data", Types.StringType.get()));
+
+    assertThat(validationCatalog.loadTable(tableIdent).schema().asStruct())
+        .as("Schema should match expected")
+        .isEqualTo(expectedSchema);
+
+    sql("ALTER TABLE %s ADD COLUMN point.z double COMMENT 'May be null' FIRST", tableName);
+
+    Types.StructType expectedSchema2 =
+        Types.StructType.of(
+            NestedField.required(1, "id", Types.LongType.get()),
+            NestedField.optional(
+                3,
+                "point",
+                Types.StructType.of(
+                    NestedField.optional(6, "z", Types.DoubleType.get(), "May be null"),
+                    NestedField.required(4, "x", Types.DoubleType.get()),
+                    NestedField.required(5, "y", Types.DoubleType.get()))),
+            NestedField.optional(2, "data", Types.StringType.get()));
+
+    assertThat(validationCatalog.loadTable(tableIdent).schema().asStruct())
+        .as("Schema should match expected")
+        .isEqualTo(expectedSchema2);
+  }
+
+  @TestTemplate
+  public void testAddColumnWithArray() {
+    sql("ALTER TABLE %s ADD COLUMN data2 array<struct<a:INT, b:INT, c:INT>>", tableName);
+    // use the implicit column name 'element' to access member of array and add column d to struct.
+    sql("ALTER TABLE %s ADD COLUMN data2.element.d int", tableName);
+    Types.StructType expectedSchema =
+        Types.StructType.of(
+            NestedField.required(1, "id", Types.LongType.get()),
+            NestedField.optional(2, "data", Types.StringType.get()),
+            NestedField.optional(
+                3,
+                "data2",
+                Types.ListType.ofOptional(
+                    4,
+                    Types.StructType.of(
+                        NestedField.optional(5, "a", Types.IntegerType.get()),
+                        NestedField.optional(6, "b", Types.IntegerType.get()),
+                        NestedField.optional(7, "c", Types.IntegerType.get()),
+                        NestedField.optional(8, "d", Types.IntegerType.get())))));
+    assertThat(validationCatalog.loadTable(tableIdent).schema().asStruct())
+        .as("Schema should match expected")
+        .isEqualTo(expectedSchema);
+  }
+
+  @TestTemplate
+  public void testAddColumnWithMap() {
+    sql("ALTER TABLE %s ADD COLUMN data2 map<struct<x:INT>, struct<a:INT, b:INT>>", tableName);
+    // use the implicit column name 'key' and 'value' to access member of map.
+ // add column to value struct column + sql("ALTER TABLE %s ADD COLUMN data2.value.c int", tableName); + Types.StructType expectedSchema = + Types.StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get()), + NestedField.optional( + 3, + "data2", + Types.MapType.ofOptional( + 4, + 5, + Types.StructType.of(NestedField.optional(6, "x", Types.IntegerType.get())), + Types.StructType.of( + NestedField.optional(7, "a", Types.IntegerType.get()), + NestedField.optional(8, "b", Types.IntegerType.get()), + NestedField.optional(9, "c", Types.IntegerType.get()))))); + assertThat(validationCatalog.loadTable(tableIdent).schema().asStruct()) + .as("Schema should match expected") + .isEqualTo(expectedSchema); + + // should not allow changing map key column + assertThatThrownBy(() -> sql("ALTER TABLE %s ADD COLUMN data2.key.y int", tableName)) + .isInstanceOf(SparkException.class) + .hasMessageStartingWith("Unsupported table change: Cannot add fields to map keys:"); + } + + @TestTemplate + public void testDropColumn() { + sql("ALTER TABLE %s DROP COLUMN data", tableName); + + Types.StructType expectedSchema = + Types.StructType.of(NestedField.required(1, "id", Types.LongType.get())); + + assertThat(validationCatalog.loadTable(tableIdent).schema().asStruct()) + .as("Schema should match expected") + .isEqualTo(expectedSchema); + } + + @TestTemplate + public void testRenameColumn() { + sql("ALTER TABLE %s RENAME COLUMN id TO row_id", tableName); + + Types.StructType expectedSchema = + Types.StructType.of( + NestedField.required(1, "row_id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + assertThat(validationCatalog.loadTable(tableIdent).schema().asStruct()) + .as("Schema should match expected") + .isEqualTo(expectedSchema); + } + + @TestTemplate + public void testAlterColumnComment() { + sql("ALTER TABLE %s ALTER COLUMN id COMMENT 'Record id'", tableName); + + Types.StructType expectedSchema = + Types.StructType.of( + NestedField.required(1, "id", Types.LongType.get(), "Record id"), + NestedField.optional(2, "data", Types.StringType.get())); + + assertThat(validationCatalog.loadTable(tableIdent).schema().asStruct()) + .as("Schema should match expected") + .isEqualTo(expectedSchema); + } + + @TestTemplate + public void testAlterColumnType() { + sql("ALTER TABLE %s ADD COLUMN count int", tableName); + sql("ALTER TABLE %s ALTER COLUMN count TYPE bigint", tableName); + + Types.StructType expectedSchema = + Types.StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get()), + NestedField.optional(3, "count", Types.LongType.get())); + + assertThat(validationCatalog.loadTable(tableIdent).schema().asStruct()) + .as("Schema should match expected") + .isEqualTo(expectedSchema); + } + + @TestTemplate + public void testAlterColumnDropNotNull() { + sql("ALTER TABLE %s ALTER COLUMN id DROP NOT NULL", tableName); + + Types.StructType expectedSchema = + Types.StructType.of( + NestedField.optional(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + assertThat(validationCatalog.loadTable(tableIdent).schema().asStruct()) + .as("Schema should match expected") + .isEqualTo(expectedSchema); + } + + @TestTemplate + public void testAlterColumnSetNotNull() { + // no-op changes are allowed + sql("ALTER TABLE %s ALTER COLUMN id SET NOT NULL", tableName); + + Types.StructType expectedSchema = + Types.StructType.of( + 
NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + assertThat(validationCatalog.loadTable(tableIdent).schema().asStruct()) + .as("Schema should match expected") + .isEqualTo(expectedSchema); + + assertThatThrownBy(() -> sql("ALTER TABLE %s ALTER COLUMN data SET NOT NULL", tableName)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith("Cannot change nullable column to non-nullable: data"); + } + + @TestTemplate + public void testAlterColumnPositionAfter() { + sql("ALTER TABLE %s ADD COLUMN count int", tableName); + sql("ALTER TABLE %s ALTER COLUMN count AFTER id", tableName); + + Types.StructType expectedSchema = + Types.StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(3, "count", Types.IntegerType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + assertThat(validationCatalog.loadTable(tableIdent).schema().asStruct()) + .as("Schema should match expected") + .isEqualTo(expectedSchema); + } + + @TestTemplate + public void testAlterColumnPositionFirst() { + sql("ALTER TABLE %s ADD COLUMN count int", tableName); + sql("ALTER TABLE %s ALTER COLUMN count FIRST", tableName); + + Types.StructType expectedSchema = + Types.StructType.of( + NestedField.optional(3, "count", Types.IntegerType.get()), + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + assertThat(validationCatalog.loadTable(tableIdent).schema().asStruct()) + .as("Schema should match expected") + .isEqualTo(expectedSchema); + } + + @TestTemplate + public void testTableRename() { + assumeThat(validationCatalog) + .as("Hadoop catalog does not support rename") + .isNotInstanceOf(HadoopCatalog.class); + + assertThat(validationCatalog.tableExists(tableIdent)).as("Initial name should exist").isTrue(); + assertThat(validationCatalog.tableExists(renamedIdent)) + .as("New name should not exist") + .isFalse(); + + sql("ALTER TABLE %s RENAME TO %s2", tableName, tableName); + + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Initial name should not exist") + .isFalse(); + assertThat(validationCatalog.tableExists(renamedIdent)).as("New name should exist").isTrue(); + } + + @TestTemplate + public void testSetTableProperties() { + sql("ALTER TABLE %s SET TBLPROPERTIES ('prop'='value')", tableName); + + assertThat(validationCatalog.loadTable(tableIdent).properties().get("prop")) + .as("Should have the new table property") + .isEqualTo("value"); + + sql("ALTER TABLE %s UNSET TBLPROPERTIES ('prop')", tableName); + + assertThat(validationCatalog.loadTable(tableIdent).properties().get("prop")) + .as("Should not have the removed table property") + .isNull(); + + String[] reservedProperties = new String[] {"sort-order", "identifier-fields"}; + for (String reservedProp : reservedProperties) { + assertThatThrownBy( + () -> sql("ALTER TABLE %s SET TBLPROPERTIES ('%s'='value')", tableName, reservedProp)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageStartingWith( + "Cannot specify the '%s' because it's a reserved table property", reservedProp); + } + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTable.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTable.java new file mode 100644 index 000000000000..11d4cfebfea6 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTable.java @@ -0,0 +1,462 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.File; +import java.nio.file.Files; +import java.util.UUID; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopCatalog; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StructType; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestCreateTable extends CatalogTestBase { + + @AfterEach + public void dropTestTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testTransformIgnoreCase() { + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + .isFalse(); + sql( + "CREATE TABLE IF NOT EXISTS %s (id BIGINT NOT NULL, ts timestamp) " + + "USING iceberg partitioned by (HOURS(ts))", + tableName); + assertThat(validationCatalog.tableExists(tableIdent)).as("Table should already exist").isTrue(); + sql( + "CREATE TABLE IF NOT EXISTS %s (id BIGINT NOT NULL, ts timestamp) " + + "USING iceberg partitioned by (hours(ts))", + tableName); + assertThat(validationCatalog.tableExists(tableIdent)).as("Table should already exist").isTrue(); + } + + @TestTemplate + public void testTransformSingularForm() { + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + .isFalse(); + sql( + "CREATE TABLE IF NOT EXISTS %s (id BIGINT NOT NULL, ts timestamp) " + + "USING iceberg partitioned by (hour(ts))", + tableName); + assertThat(validationCatalog.tableExists(tableIdent)).as("Table should exist").isTrue(); + } + + @TestTemplate + public void testTransformPluralForm() { + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + .isFalse(); + sql( + "CREATE TABLE IF NOT EXISTS %s (id BIGINT NOT NULL, ts timestamp) " + + "USING iceberg partitioned by (hours(ts))", + tableName); + assertThat(validationCatalog.tableExists(tableIdent)).as("Table should exist").isTrue(); + } + + @TestTemplate + public void testCreateTable() { + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + 
.isFalse(); + + sql("CREATE TABLE %s (id BIGINT NOT NULL, data STRING) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table).as("Should load the new table").isNotNull(); + + StructType expectedSchema = + StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + assertThat(table.schema().asStruct()) + .as("Should have the expected schema") + .isEqualTo(expectedSchema); + assertThat(table.spec().fields()).as("Should not be partitioned").hasSize(0); + assertThat(table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)) + .as("Should not have the default format set") + .isNull(); + } + + @TestTemplate + public void testCreateTablePartitionedByUUID() { + assertThat(validationCatalog.tableExists(tableIdent)).isFalse(); + Schema schema = new Schema(1, Types.NestedField.optional(1, "uuid", Types.UUIDType.get())); + PartitionSpec spec = PartitionSpec.builderFor(schema).bucket("uuid", 16).build(); + validationCatalog.createTable(tableIdent, schema, spec); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table).isNotNull(); + + StructType expectedSchema = + StructType.of(Types.NestedField.optional(1, "uuid", Types.UUIDType.get())); + assertThat(table.schema().asStruct()).isEqualTo(expectedSchema); + assertThat(table.spec().fields()).hasSize(1); + + String uuid = UUID.randomUUID().toString(); + + sql("INSERT INTO %s VALUES('%s')", tableName, uuid); + + assertThat(sql("SELECT uuid FROM %s", tableName)).hasSize(1).element(0).isEqualTo(row(uuid)); + } + + @TestTemplate + public void testCreateTableInRootNamespace() { + assumeThat(catalogName) + .as("Hadoop has no default namespace configured") + .isEqualTo("testhadoop"); + + try { + sql("CREATE TABLE %s.table (id bigint) USING iceberg", catalogName); + } finally { + sql("DROP TABLE IF EXISTS %s.table", catalogName); + } + } + + @TestTemplate + public void testCreateTableUsingParquet() { + assumeThat(catalogName) + .as("Not working with session catalog because Spark will not use v2 for a Parquet table") + .isNotEqualTo("spark_catalog"); + + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + .isFalse(); + + sql("CREATE TABLE %s (id BIGINT NOT NULL, data STRING) USING parquet", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table).as("Should load the new table").isNotNull(); + + StructType expectedSchema = + StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + assertThat(table.schema().asStruct()) + .as("Should have the expected schema") + .isEqualTo(expectedSchema); + assertThat(table.spec().fields()).as("Should not be partitioned").hasSize(0); + assertThat(table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)) + .as("Should not have default format parquet") + .isEqualTo("parquet"); + + assertThatThrownBy( + () -> + sql( + "CREATE TABLE %s.default.fail (id BIGINT NOT NULL, data STRING) USING crocodile", + catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Unsupported format in USING: crocodile"); + } + + @TestTemplate + public void testCreateTablePartitionedBy() { + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + .isFalse(); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, created_at TIMESTAMP, category STRING, data STRING) " + + "USING iceberg " + + "PARTITIONED 
BY (category, bucket(8, id), days(created_at))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table).as("Should load the new table").isNotNull(); + + StructType expectedSchema = + StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "created_at", Types.TimestampType.withZone()), + NestedField.optional(3, "category", Types.StringType.get()), + NestedField.optional(4, "data", Types.StringType.get())); + assertThat(table.schema().asStruct()) + .as("Should have the expected schema") + .isEqualTo(expectedSchema); + + PartitionSpec expectedSpec = + PartitionSpec.builderFor(new Schema(expectedSchema.fields())) + .identity("category") + .bucket("id", 8) + .day("created_at") + .build(); + assertThat(table.spec()).as("Should be partitioned correctly").isEqualTo(expectedSpec); + + assertThat(table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)) + .as("Should not have the default format set") + .isNull(); + } + + @TestTemplate + public void testCreateTableColumnComments() { + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + .isFalse(); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL COMMENT 'Unique identifier', data STRING COMMENT 'Data value') " + + "USING iceberg", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table).as("Should load the new table").isNotNull(); + + StructType expectedSchema = + StructType.of( + NestedField.required(1, "id", Types.LongType.get(), "Unique identifier"), + NestedField.optional(2, "data", Types.StringType.get(), "Data value")); + assertThat(table.schema().asStruct()) + .as("Should have the expected schema") + .isEqualTo(expectedSchema); + assertThat(table.spec().fields()).as("Should not be partitioned").hasSize(0); + assertThat(table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)) + .as("Should not have the default format set") + .isNull(); + } + + @TestTemplate + public void testCreateTableComment() { + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + .isFalse(); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "COMMENT 'Table doc'", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table).as("Should load the new table").isNotNull(); + + StructType expectedSchema = + StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + assertThat(table.schema().asStruct()) + .as("Should have the expected schema") + .isEqualTo(expectedSchema); + assertThat(table.spec().fields()).as("Should not be partitioned").hasSize(0); + assertThat(table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)) + .as("Should not have the default format set") + .isNull(); + assertThat(table.properties().get(TableCatalog.PROP_COMMENT)) + .as("Should have the table comment set in properties") + .isEqualTo("Table doc"); + } + + @TestTemplate + public void testCreateTableLocation() throws Exception { + assumeThat(validationCatalog) + .as("Cannot set custom locations for Hadoop catalog tables") + .isNotInstanceOf(HadoopCatalog.class); + + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + .isFalse(); + + File tableLocation = Files.createTempDirectory(temp, "junit").toFile(); + assertThat(tableLocation.delete()).isTrue(); + + String location = "file:" + tableLocation; + + 
sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "LOCATION '%s'", + tableName, location); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table).as("Should load the new table").isNotNull(); + + StructType expectedSchema = + StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + assertThat(table.schema().asStruct()) + .as("Should have the expected schema") + .isEqualTo(expectedSchema); + assertThat(table.spec().fields()).as("Should not be partitioned").hasSize(0); + assertThat(table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)) + .as("Should not have the default format set") + .isNull(); + assertThat(table.location()).as("Should have a custom table location").isEqualTo(location); + } + + @TestTemplate + public void testCreateTableProperties() { + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + .isFalse(); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES (p1=2, p2='x')", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table).as("Should load the new table").isNotNull(); + + StructType expectedSchema = + StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + assertThat(table.schema().asStruct()) + .as("Should have the expected schema") + .isEqualTo(expectedSchema); + assertThat(table.spec().fields()).as("Should not be partitioned").hasSize(0); + assertThat(table.properties()).containsEntry("p1", "2").containsEntry("p2", "x"); + } + + @TestTemplate + public void testCreateTableCommitProperties() { + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + .isFalse(); + + assertThatThrownBy( + () -> + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES ('commit.retry.num-retries'='x', p2='x')", + tableName)) + .isInstanceOf(ValidationException.class) + .hasMessage("Table property commit.retry.num-retries must have integer value"); + + assertThatThrownBy( + () -> + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES ('commit.retry.max-wait-ms'='-1')", + tableName)) + .isInstanceOf(ValidationException.class) + .hasMessage("Table property commit.retry.max-wait-ms must have non negative integer value"); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES ('commit.retry.num-retries'='1', 'commit.retry.max-wait-ms'='3000')", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.properties()) + .containsEntry(TableProperties.COMMIT_NUM_RETRIES, "1") + .containsEntry(TableProperties.COMMIT_MAX_RETRY_WAIT_MS, "3000"); + } + + @TestTemplate + public void testCreateTableWithFormatV2ThroughTableProperty() { + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + .isFalse(); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES ('format-version'='2')", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(((BaseTable) table).operations().current().formatVersion()) + .as("should create table using format v2") + .isEqualTo(2); + } + + @TestTemplate + public void 
testUpgradeTableWithFormatV2ThroughTableProperty() { + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + .isFalse(); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES ('format-version'='1')", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + TableOperations ops = ((BaseTable) table).operations(); + assertThat(ops.refresh().formatVersion()) + .as("should create table using format v1") + .isEqualTo(1); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('format-version'='2')", tableName); + assertThat(ops.refresh().formatVersion()) + .as("should update table to use format v2") + .isEqualTo(2); + } + + @TestTemplate + public void testDowngradeTableToFormatV1ThroughTablePropertyFails() { + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not already exist") + .isFalse(); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES ('format-version'='2')", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + TableOperations ops = ((BaseTable) table).operations(); + assertThat(ops.refresh().formatVersion()) + .as("should create table using format v2") + .isEqualTo(2); + + assertThatThrownBy( + () -> sql("ALTER TABLE %s SET TBLPROPERTIES ('format-version'='1')", tableName)) + .cause() + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot downgrade v2 table to v1"); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTableAsSelect.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTableAsSelect.java new file mode 100644 index 000000000000..4098a155be0d --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTableAsSelect.java @@ -0,0 +1,454 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.when; +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.iceberg.Parameter; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.types.Types; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestCreateTableAsSelect extends CatalogTestBase { + + @Parameter(index = 3) + private String sourceName; + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}, sourceName = {3}") + protected static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties(), + SparkCatalogConfig.HIVE.catalogName() + ".default.source" + }, + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + SparkCatalogConfig.HADOOP.catalogName() + ".default.source" + }, + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties(), + "default.source" + } + }; + } + + @BeforeEach + public void createTableIfNotExists() { + sql( + "CREATE TABLE IF NOT EXISTS %s (id bigint NOT NULL, data string) " + + "USING iceberg PARTITIONED BY (truncate(id, 3))", + sourceName); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), (5, 'e')", sourceName); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testUnpartitionedCTAS() { + sql("CREATE TABLE %s USING iceberg AS SELECT * FROM %s", tableName, sourceName); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get())); + + Table ctasTable = validationCatalog.loadTable(tableIdent); + + assertThat(ctasTable.schema().asStruct()) + .as("Should have expected nullable schema") + .isEqualTo(expectedSchema.asStruct()); + + assertThat(ctasTable.spec().fields()).as("Should be an unpartitioned table").hasSize(0); + assertEquals( + "Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testPartitionedCTAS() { + sql( + "CREATE TABLE %s USING iceberg PARTITIONED BY (id) AS SELECT * FROM %s ORDER BY id", + tableName, sourceName); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get())); + + PartitionSpec expectedSpec = PartitionSpec.builderFor(expectedSchema).identity("id").build(); + + Table ctasTable = validationCatalog.loadTable(tableIdent); + + assertThat(ctasTable.schema().asStruct()) + .as("Should have expected nullable schema") + .isEqualTo(expectedSchema.asStruct()); + assertThat(ctasTable.spec()).as("Should be partitioned by id").isEqualTo(expectedSpec); + assertEquals( + "Should have rows matching the source table", + sql("SELECT * FROM %s ORDER 
BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testCTASWriteDistributionModeRespected() { + sql( + "CREATE TABLE %s USING iceberg PARTITIONED BY (bucket(2, id)) AS SELECT * FROM %s", + tableName, sourceName); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get())); + + PartitionSpec expectedSpec = PartitionSpec.builderFor(expectedSchema).bucket("id", 2).build(); + + Table ctasTable = validationCatalog.loadTable(tableIdent); + + assertThat(ctasTable.schema().asStruct()) + .as("Should have expected nullable schema") + .isEqualTo(expectedSchema.asStruct()); + assertThat(ctasTable.spec()).as("Should be partitioned by id").isEqualTo(expectedSpec); + assertEquals( + "Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testRTAS() { + sql( + "CREATE TABLE %s USING iceberg TBLPROPERTIES ('prop1'='val1', 'prop2'='val2')" + + "AS SELECT * FROM %s", + tableName, sourceName); + + assertEquals( + "Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql( + "REPLACE TABLE %s USING iceberg PARTITIONED BY (part) TBLPROPERTIES ('prop1'='newval1', 'prop3'='val3') AS " + + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", + tableName, sourceName); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get())); + + PartitionSpec expectedSpec = + PartitionSpec.builderFor(expectedSchema).identity("part").withSpecId(1).build(); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + // the replacement table has a different schema and partition spec than the original + assertThat(rtasTable.schema().asStruct()) + .as("Should have expected nullable schema") + .isEqualTo(expectedSchema.asStruct()); + assertThat(rtasTable.spec()).as("Should be partitioned by part").isEqualTo(expectedSpec); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + assertThat(rtasTable.snapshots()).as("Table should have expected snapshots").hasSize(2); + assertThat(rtasTable.properties().get("prop1")) + .as("Should have updated table property") + .isEqualTo("newval1"); + assertThat(rtasTable.properties().get("prop2")) + .as("Should have preserved table property") + .isEqualTo("val2"); + assertThat(rtasTable.properties().get("prop3")) + .as("Should have new table property") + .isEqualTo("val3"); + } + + @TestTemplate + public void testCreateRTAS() { + sql( + "CREATE OR REPLACE TABLE %s USING iceberg PARTITIONED BY (part) AS " + + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", + tableName, sourceName); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql( + 
"CREATE OR REPLACE TABLE %s USING iceberg PARTITIONED BY (part) AS " + + "SELECT 2 * id as id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", + tableName, sourceName); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get())); + + PartitionSpec expectedSpec = + PartitionSpec.builderFor(expectedSchema) + .identity("part") + .withSpecId(0) // the spec is identical and should be reused + .build(); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + // the replacement table has a different schema and partition spec than the original + assertThat(rtasTable.schema().asStruct()) + .as("Should have expected nullable schema") + .isEqualTo(expectedSchema.asStruct()); + assertThat(rtasTable.spec()).as("Should be partitioned by part").isEqualTo(expectedSpec); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT 2 * id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + assertThat(rtasTable.snapshots()).as("Table should have expected snapshots").hasSize(2); + } + + @TestTemplate + public void testDataFrameV2Create() throws Exception { + spark.table(sourceName).writeTo(tableName).using("iceberg").create(); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get())); + + Table ctasTable = validationCatalog.loadTable(tableIdent); + + assertThat(ctasTable.schema().asStruct()) + .as("Should have expected nullable schema") + .isEqualTo(expectedSchema.asStruct()); + assertThat(ctasTable.spec().fields()).as("Should be an unpartitioned table").hasSize(0); + assertEquals( + "Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testDataFrameV2Replace() throws Exception { + spark.table(sourceName).writeTo(tableName).using("iceberg").create(); + + assertEquals( + "Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + spark + .table(sourceName) + .select( + col("id"), + col("data"), + when(col("id").mod(lit(2)).equalTo(lit(0)), lit("even")).otherwise("odd").as("part")) + .orderBy("part", "id") + .writeTo(tableName) + .partitionedBy(col("part")) + .using("iceberg") + .replace(); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get())); + + PartitionSpec expectedSpec = + PartitionSpec.builderFor(expectedSchema).identity("part").withSpecId(1).build(); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + // the replacement table has a different schema and partition spec than the original + assertThat(rtasTable.schema().asStruct()) + .as("Should have expected nullable schema") + .isEqualTo(expectedSchema.asStruct()); + assertThat(rtasTable.spec()).as("Should be partitioned by part").isEqualTo(expectedSpec); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT id, data, CASE WHEN (id %% 2) = 
0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + assertThat(rtasTable.snapshots()).as("Table should have expected snapshots").hasSize(2); + } + + @TestTemplate + public void testDataFrameV2CreateOrReplace() { + spark + .table(sourceName) + .select( + col("id"), + col("data"), + when(col("id").mod(lit(2)).equalTo(lit(0)), lit("even")).otherwise("odd").as("part")) + .orderBy("part", "id") + .writeTo(tableName) + .partitionedBy(col("part")) + .using("iceberg") + .createOrReplace(); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + spark + .table(sourceName) + .select(col("id").multiply(lit(2)).as("id"), col("data")) + .select( + col("id"), + col("data"), + when(col("id").mod(lit(2)).equalTo(lit(0)), lit("even")).otherwise("odd").as("part")) + .orderBy("part", "id") + .writeTo(tableName) + .partitionedBy(col("part")) + .using("iceberg") + .createOrReplace(); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get())); + + PartitionSpec expectedSpec = + PartitionSpec.builderFor(expectedSchema) + .identity("part") + .withSpecId(0) // the spec is identical and should be reused + .build(); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + // the replacement table has a different schema and partition spec than the original + assertThat(rtasTable.schema().asStruct()) + .as("Should have expected nullable schema") + .isEqualTo(expectedSchema.asStruct()); + assertThat(rtasTable.spec()).as("Should be partitioned by part").isEqualTo(expectedSpec); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT 2 * id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + assertThat(rtasTable.snapshots()).as("Table should have expected snapshots").hasSize(2); + } + + @TestTemplate + public void testCreateRTASWithPartitionSpecChanging() { + sql( + "CREATE OR REPLACE TABLE %s USING iceberg PARTITIONED BY (part) AS " + + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", + tableName, sourceName); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + // Change the partitioning of the table + rtasTable.updateSpec().removeField("part").commit(); // Spec 1 + + sql( + "CREATE OR REPLACE TABLE %s USING iceberg PARTITIONED BY (part, id) AS " + + "SELECT 2 * id as id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", + tableName, sourceName); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get())); + + PartitionSpec expectedSpec = + PartitionSpec.builderFor(expectedSchema) + 
.identity("part") + .identity("id") + .withSpecId(2) // The Spec is new + .build(); + + assertThat(rtasTable.spec()).as("Should be partitioned by part and id").isEqualTo(expectedSpec); + + // the replacement table has a different schema and partition spec than the original + assertThat(rtasTable.schema().asStruct()) + .as("Should have expected nullable schema") + .isEqualTo(expectedSchema.asStruct()); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT 2 * id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + assertThat(rtasTable.snapshots()).as("Table should have expected snapshots").hasSize(2); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestDeleteFrom.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestDeleteFrom.java new file mode 100644 index 000000000000..7706c5aad4de --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestDeleteFrom.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestDeleteFrom extends CatalogTestBase { + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testDeleteFromUnpartitionedTable() throws NoSuchTableException { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + List<SimpleRecord> records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class); + df.coalesce(1).writeTo(tableName).append(); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DELETE FROM %s WHERE id < 2", tableName); + + assertEquals( + "Should have two remaining rows after deleting id < 2", + ImmutableList.of(row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DELETE FROM %s WHERE id < 4", tableName); + + assertEquals( + "Should have no rows after successful delete", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testDeleteFromTableAtSnapshot() throws NoSuchTableException { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + List<SimpleRecord> records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class); + df.coalesce(1).writeTo(tableName).append(); + + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + String prefix = "snapshot_id_"; + assertThatThrownBy(() -> sql("DELETE FROM %s.%s WHERE id < 4", tableName, prefix + snapshotId)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Cannot delete from table at a specific snapshot"); + } + + @TestTemplate + public void testDeleteFromPartitionedTable() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (truncate(id, 2))", + tableName); + + List<SimpleRecord> records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class); + df.coalesce(1).writeTo(tableName).append(); + + assertEquals( + "Should have 3 rows in 2 partitions", + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DELETE FROM %s WHERE id > 2", tableName); + assertEquals( + "Should have two remaining rows across both partitions", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DELETE FROM %s WHERE id < 2", tableName); + + assertEquals( + "Should have one row in the second 
partition", + ImmutableList.of(row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testDeleteFromWhereFalse() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 1 snapshot").hasSize(1); + + sql("DELETE FROM %s WHERE false", tableName); + + table.refresh(); + + assertThat(table.snapshots()).as("Delete should not produce a new snapshot").hasSize(1); + } + + @TestTemplate + public void testTruncate() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(table.snapshots()).as("Should have 1 snapshot").hasSize(1); + + sql("TRUNCATE TABLE %s", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestDropTable.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestDropTable.java new file mode 100644 index 000000000000..ec8308e1c772 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestDropTable.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Streams; +import org.apache.iceberg.spark.CatalogTestBase; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestDropTable extends CatalogTestBase { + + @BeforeEach + public void createTable() { + sql("CREATE TABLE %s (id INT, name STRING) USING iceberg", tableName); + sql("INSERT INTO %s VALUES (1, 'test')", tableName); + } + + @AfterEach + public void removeTable() throws IOException { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testDropTable() throws IOException { + dropTableInternal(); + } + + @TestTemplate + public void testDropTableGCDisabled() throws IOException { + sql("ALTER TABLE %s SET TBLPROPERTIES (gc.enabled = false)", tableName); + dropTableInternal(); + } + + private void dropTableInternal() throws IOException { + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "test")), + sql("SELECT * FROM %s", tableName)); + + List manifestAndFiles = manifestsAndFiles(); + assertThat(manifestAndFiles).as("There should be 2 files for manifests and files").hasSize(2); + assertThat(checkFilesExist(manifestAndFiles, true)).as("All files should exist").isTrue(); + + sql("DROP TABLE %s", tableName); + assertThat(validationCatalog.tableExists(tableIdent)).as("Table should not exist").isFalse(); + + if (catalogName.equals("testhadoop")) { + // HadoopCatalog drop table without purge will delete the base table location. 
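+ // Metastore-backed catalogs, by contrast, only drop the catalog entry, so the data and metadata files are expected to survive (see the else branch).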
+ assertThat(checkFilesExist(manifestAndFiles, false)) + .as("All files should be deleted") + .isTrue(); + } else { + assertThat(checkFilesExist(manifestAndFiles, true)) + .as("All files should not be deleted") + .isTrue(); + } + } + + @TestTemplate + public void testPurgeTable() throws IOException { + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "test")), + sql("SELECT * FROM %s", tableName)); + + List manifestAndFiles = manifestsAndFiles(); + assertThat(manifestAndFiles).as("There should be 2 files for manifests and files").hasSize(2); + assertThat(checkFilesExist(manifestAndFiles, true)).as("All files should exist").isTrue(); + + sql("DROP TABLE %s PURGE", tableName); + assertThat(validationCatalog.tableExists(tableIdent)).as("Table should not exist").isFalse(); + assertThat(checkFilesExist(manifestAndFiles, false)).as("All files should be deleted").isTrue(); + } + + @TestTemplate + public void testPurgeTableGCDisabled() throws IOException { + sql("ALTER TABLE %s SET TBLPROPERTIES (gc.enabled = false)", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "test")), + sql("SELECT * FROM %s", tableName)); + + List manifestAndFiles = manifestsAndFiles(); + assertThat(manifestAndFiles).as("There should be 2 files for manifests and files").hasSize(2); + assertThat(checkFilesExist(manifestAndFiles, true)).as("All files should exist").isTrue(); + + assertThatThrownBy(() -> sql("DROP TABLE %s PURGE", tableName)) + .isInstanceOf(ValidationException.class) + .hasMessageContaining( + "Cannot purge table: GC is disabled (deleting files may corrupt other tables"); + + assertThat(validationCatalog.tableExists(tableIdent)) + .as("Table should not been dropped") + .isTrue(); + assertThat(checkFilesExist(manifestAndFiles, true)) + .as("All files should not be deleted") + .isTrue(); + } + + private List manifestsAndFiles() { + List files = sql("SELECT file_path FROM %s.%s", tableName, MetadataTableType.FILES); + List manifests = + sql("SELECT path FROM %s.%s", tableName, MetadataTableType.MANIFESTS); + return Streams.concat(files.stream(), manifests.stream()) + .map(row -> (String) row[0]) + .collect(Collectors.toList()); + } + + private boolean checkFilesExist(List files, boolean shouldExist) throws IOException { + boolean mask = !shouldExist; + if (files.isEmpty()) { + return mask; + } + + FileSystem fs = new Path(files.get(0)).getFileSystem(hiveConf); + return files.stream() + .allMatch( + file -> { + try { + return fs.exists(new Path(file)) ^ mask; + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java new file mode 100644 index 000000000000..9d2ce2b388a2 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java @@ -0,0 +1,616 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; +import static org.assertj.core.api.Assertions.assertThat; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.List; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PlanningMode; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.spark.sql.execution.SparkPlan; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestFilterPushDown extends TestBaseWithCatalog { + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}, planningMode = {3}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + LOCAL + }, + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + DISTRIBUTED + } + }; + } + + @Parameter(index = 3) + private PlanningMode planningMode; + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS tmp_view"); + } + + @TestTemplate + public void testFilterPushdownWithDecimalValues() { + sql( + "CREATE TABLE %s (id INT, salary DECIMAL(10, 2), dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)", + tableName); + configurePlanningMode(planningMode); + + sql("INSERT INTO %s VALUES (1, 100.01, 'd1')", tableName); + sql("INSERT INTO %s VALUES (2, 100.05, 'd1')", tableName); + + checkFilters( + "dep = 'd1' AND salary > 100.03" /* query predicate */, + "isnotnull(salary) AND (salary > 100.03)" /* Spark post scan filter */, + "dep IS NOT NULL, salary IS NOT NULL, dep = 'd1', salary > 100.03" /* Iceberg scan filters */, + ImmutableList.of(row(2, new BigDecimal("100.05"), "d1"))); + } + + @TestTemplate + public void testFilterPushdownWithIdentityTransform() { + sql( + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)", + tableName); + configurePlanningMode(planningMode); + + sql("INSERT INTO %s VALUES (1, 100, 'd1')", tableName); + sql("INSERT INTO %s VALUES (2, 200, 'd2')", tableName); + sql("INSERT INTO %s VALUES (3, 300, 'd3')", tableName); + sql("INSERT INTO %s VALUES (4, 400, 'd4')", tableName); + sql("INSERT INTO %s VALUES (5, 500, 'd5')", tableName); + sql("INSERT INTO %s VALUES (6, 600, null)", tableName); + + checkOnlyIcebergFilters( + "dep IS NULL" /* query predicate */, + "dep IS NULL" /* Iceberg scan filters */, 
+ ImmutableList.of(row(6, 600, null))); + + checkOnlyIcebergFilters( + "dep IS NOT NULL" /* query predicate */, + "dep IS NOT NULL" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, "d1"), + row(2, 200, "d2"), + row(3, 300, "d3"), + row(4, 400, "d4"), + row(5, 500, "d5"))); + + checkOnlyIcebergFilters( + "dep = 'd3'" /* query predicate */, + "dep IS NOT NULL, dep = 'd3'" /* Iceberg scan filters */, + ImmutableList.of(row(3, 300, "d3"))); + + checkOnlyIcebergFilters( + "dep > 'd3'" /* query predicate */, + "dep IS NOT NULL, dep > 'd3'" /* Iceberg scan filters */, + ImmutableList.of(row(4, 400, "d4"), row(5, 500, "d5"))); + + checkOnlyIcebergFilters( + "dep >= 'd5'" /* query predicate */, + "dep IS NOT NULL, dep >= 'd5'" /* Iceberg scan filters */, + ImmutableList.of(row(5, 500, "d5"))); + + checkOnlyIcebergFilters( + "dep < 'd2'" /* query predicate */, + "dep IS NOT NULL, dep < 'd2'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + checkOnlyIcebergFilters( + "dep <= 'd2'" /* query predicate */, + "dep IS NOT NULL, dep <= 'd2'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"), row(2, 200, "d2"))); + + checkOnlyIcebergFilters( + "dep <=> 'd3'" /* query predicate */, + "dep = 'd3'" /* Iceberg scan filters */, + ImmutableList.of(row(3, 300, "d3"))); + + checkOnlyIcebergFilters( + "dep IN (null, 'd1')" /* query predicate */, + "dep IN ('d1')" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + checkOnlyIcebergFilters( + "dep NOT IN ('d2', 'd4')" /* query predicate */, + "(dep IS NOT NULL AND dep NOT IN ('d2', 'd4'))" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"), row(3, 300, "d3"), row(5, 500, "d5"))); + + checkOnlyIcebergFilters( + "dep = 'd1' AND dep IS NOT NULL" /* query predicate */, + "dep = 'd1', dep IS NOT NULL" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + checkOnlyIcebergFilters( + "dep = 'd1' OR dep = 'd2' OR dep = 'd3'" /* query predicate */, + "((dep = 'd1' OR dep = 'd2') OR dep = 'd3')" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"), row(2, 200, "d2"), row(3, 300, "d3"))); + + checkFilters( + "dep = 'd1' AND id = 1" /* query predicate */, + "isnotnull(id) AND (id = 1)" /* Spark post scan filter */, + "dep IS NOT NULL, id IS NOT NULL, dep = 'd1', id = 1" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + checkFilters( + "dep = 'd2' OR id = 1" /* query predicate */, + "(dep = d2) OR (id = 1)" /* Spark post scan filter */, + "(dep = 'd2' OR id = 1)" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"), row(2, 200, "d2"))); + + checkFilters( + "dep LIKE 'd1%' AND id = 1" /* query predicate */, + "isnotnull(id) AND (id = 1)" /* Spark post scan filter */, + "dep IS NOT NULL, id IS NOT NULL, dep LIKE 'd1%', id = 1" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + checkFilters( + "dep NOT LIKE 'd5%' AND (id = 1 OR id = 5)" /* query predicate */, + "(id = 1) OR (id = 5)" /* Spark post scan filter */, + "dep IS NOT NULL, NOT (dep LIKE 'd5%'), (id = 1 OR id = 5)" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + checkFilters( + "dep LIKE '%d5' AND id IN (1, 5)" /* query predicate */, + "EndsWith(dep, d5) AND id IN (1,5)" /* Spark post scan filter */, + "dep IS NOT NULL, id IN (1, 5)" /* Iceberg scan filters */, + ImmutableList.of(row(5, 500, "d5"))); + } + + @TestTemplate + public void testFilterPushdownWithHoursTransform() { + sql( + "CREATE TABLE %s (id INT, price INT, t TIMESTAMP)" 
+ + "USING iceberg " + + "PARTITIONED BY (hours(t))", + tableName); + configurePlanningMode(planningMode); + + sql("INSERT INTO %s VALUES (1, 100, TIMESTAMP '2021-06-30T01:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (2, 200, TIMESTAMP '2021-06-30T02:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (3, 300, null)", tableName); + + withDefaultTimeZone( + "UTC", + () -> { + checkOnlyIcebergFilters( + "t IS NULL" /* query predicate */, + "t IS NULL" /* Iceberg scan filters */, + ImmutableList.of(row(3, 300, null))); + + // strict/inclusive projections for t < TIMESTAMP '2021-06-30T02:00:00.000Z' are equal, + // so this filter selects entire partitions and can be pushed down completely + checkOnlyIcebergFilters( + "t < TIMESTAMP '2021-06-30T02:00:00.000Z'" /* query predicate */, + "t IS NOT NULL, t < 1625018400000000" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, timestamp("2021-06-30T01:00:00.0Z")))); + + // strict/inclusive projections for t < TIMESTAMP '2021-06-30T01:00:00.001Z' differ, + // so this filter does NOT select entire partitions and can't be pushed down completely + checkFilters( + "t < TIMESTAMP '2021-06-30T01:00:00.001Z'" /* query predicate */, + "t < 2021-06-30 01:00:00.001" /* Spark post scan filter */, + "t IS NOT NULL, t < 1625014800001000" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, timestamp("2021-06-30T01:00:00.0Z")))); + + // strict/inclusive projections for t <= TIMESTAMP '2021-06-30T01:00:00.000Z' differ, + // so this filter does NOT select entire partitions and can't be pushed down completely + checkFilters( + "t <= TIMESTAMP '2021-06-30T01:00:00.000Z'" /* query predicate */, + "t <= 2021-06-30 01:00:00" /* Spark post scan filter */, + "t IS NOT NULL, t <= 1625014800000000" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, timestamp("2021-06-30T01:00:00.0Z")))); + }); + } + + @TestTemplate + public void testFilterPushdownWithDaysTransform() { + sql( + "CREATE TABLE %s (id INT, price INT, t TIMESTAMP)" + + "USING iceberg " + + "PARTITIONED BY (days(t))", + tableName); + configurePlanningMode(planningMode); + + sql("INSERT INTO %s VALUES (1, 100, TIMESTAMP '2021-06-15T01:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (2, 200, TIMESTAMP '2021-06-30T02:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (3, 300, TIMESTAMP '2021-07-15T10:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (4, 400, null)", tableName); + + withDefaultTimeZone( + "UTC", + () -> { + checkOnlyIcebergFilters( + "t IS NULL" /* query predicate */, + "t IS NULL" /* Iceberg scan filters */, + ImmutableList.of(row(4, 400, null))); + + // strict/inclusive projections for t < TIMESTAMP '2021-07-05T00:00:00.000Z' are equal, + // so this filter selects entire partitions and can be pushed down completely + checkOnlyIcebergFilters( + "t < TIMESTAMP '2021-07-05T00:00:00.000Z'" /* query predicate */, + "t IS NOT NULL, t < 1625443200000000" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, timestamp("2021-06-15T01:00:00.000Z")), + row(2, 200, timestamp("2021-06-30T02:00:00.000Z")))); + + // strict/inclusive projections for t < TIMESTAMP '2021-06-30T03:00:00.000Z' differ, + // so this filter does NOT select entire partitions and can't be pushed down completely + checkFilters( + "t < TIMESTAMP '2021-06-30T03:00:00.000Z'" /* query predicate */, + "t < 2021-06-30 03:00:00" /* Spark post scan filter */, + "t IS NOT NULL, t < 1625022000000000" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, 
timestamp("2021-06-15T01:00:00.000Z")), + row(2, 200, timestamp("2021-06-30T02:00:00.000Z")))); + }); + } + + @TestTemplate + public void testFilterPushdownWithMonthsTransform() { + sql( + "CREATE TABLE %s (id INT, price INT, t TIMESTAMP)" + + "USING iceberg " + + "PARTITIONED BY (months(t))", + tableName); + configurePlanningMode(planningMode); + + sql("INSERT INTO %s VALUES (1, 100, TIMESTAMP '2021-06-30T01:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (2, 200, TIMESTAMP '2021-06-30T02:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (3, 300, TIMESTAMP '2021-07-15T10:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (4, 400, null)", tableName); + + withDefaultTimeZone( + "UTC", + () -> { + checkOnlyIcebergFilters( + "t IS NULL" /* query predicate */, + "t IS NULL" /* Iceberg scan filters */, + ImmutableList.of(row(4, 400, null))); + + // strict/inclusive projections for t < TIMESTAMP '2021-07-01T00:00:00.000Z' are equal, + // so this filter selects entire partitions and can be pushed down completely + checkOnlyIcebergFilters( + "t < TIMESTAMP '2021-07-01T00:00:00.000Z'" /* query predicate */, + "t IS NOT NULL, t < 1625097600000000" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, timestamp("2021-06-30T01:00:00.000Z")), + row(2, 200, timestamp("2021-06-30T02:00:00.000Z")))); + + // strict/inclusive projections for t < TIMESTAMP '2021-06-30T03:00:00.000Z' differ, + // so this filter does NOT select entire partitions and can't be pushed down completely + checkFilters( + "t < TIMESTAMP '2021-06-30T03:00:00.000Z'" /* query predicate */, + "t < 2021-06-30 03:00:00" /* Spark post scan filter */, + "t IS NOT NULL, t < 1625022000000000" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, timestamp("2021-06-30T01:00:00.000Z")), + row(2, 200, timestamp("2021-06-30T02:00:00.000Z")))); + }); + } + + @TestTemplate + public void testFilterPushdownWithYearsTransform() { + sql( + "CREATE TABLE %s (id INT, price INT, t TIMESTAMP)" + + "USING iceberg " + + "PARTITIONED BY (years(t))", + tableName); + configurePlanningMode(planningMode); + + sql("INSERT INTO %s VALUES (1, 100, TIMESTAMP '2021-06-30T01:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (2, 200, TIMESTAMP '2021-06-30T02:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (2, 200, TIMESTAMP '2022-09-25T02:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (3, 300, null)", tableName); + + withDefaultTimeZone( + "UTC", + () -> { + checkOnlyIcebergFilters( + "t IS NULL" /* query predicate */, + "t IS NULL" /* Iceberg scan filters */, + ImmutableList.of(row(3, 300, null))); + + // strict/inclusive projections for t < TIMESTAMP '2022-01-01T00:00:00.000Z' are equal, + // so this filter selects entire partitions and can be pushed down completely + checkOnlyIcebergFilters( + "t < TIMESTAMP '2022-01-01T00:00:00.000Z'" /* query predicate */, + "t IS NOT NULL, t < 1640995200000000" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, timestamp("2021-06-30T01:00:00.000Z")), + row(2, 200, timestamp("2021-06-30T02:00:00.000Z")))); + + // strict/inclusive projections for t < TIMESTAMP '2021-06-30T03:00:00.000Z' differ, + // so this filter does NOT select entire partitions and can't be pushed down completely + checkFilters( + "t < TIMESTAMP '2021-06-30T03:00:00.000Z'" /* query predicate */, + "t < 2021-06-30 03:00:00" /* Spark post scan filter */, + "t IS NOT NULL, t < 1625022000000000" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, timestamp("2021-06-30T01:00:00.000Z")), + 
row(2, 200, timestamp("2021-06-30T02:00:00.000Z")))); + }); + } + + @TestTemplate + public void testFilterPushdownWithBucketTransform() { + sql( + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep, bucket(8, id))", + tableName); + configurePlanningMode(planningMode); + + sql("INSERT INTO %s VALUES (1, 100, 'd1')", tableName); + sql("INSERT INTO %s VALUES (2, 200, 'd2')", tableName); + + checkFilters( + "dep = 'd1' AND id = 1" /* query predicate */, + "id = 1" /* Spark post scan filter */, + "dep IS NOT NULL, id IS NOT NULL, dep = 'd1'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + } + + @TestTemplate + public void testFilterPushdownWithTruncateTransform() { + sql( + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (truncate(1, dep))", + tableName); + configurePlanningMode(planningMode); + + sql("INSERT INTO %s VALUES (1, 100, 'd1')", tableName); + sql("INSERT INTO %s VALUES (2, 200, 'd2')", tableName); + sql("INSERT INTO %s VALUES (3, 300, 'a3')", tableName); + + checkOnlyIcebergFilters( + "dep LIKE 'd%'" /* query predicate */, + "dep IS NOT NULL, dep LIKE 'd%'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"), row(2, 200, "d2"))); + + checkFilters( + "dep = 'd1'" /* query predicate */, + "dep = d1" /* Spark post scan filter */, + "dep IS NOT NULL" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + } + + @TestTemplate + public void testFilterPushdownWithSpecEvolutionAndIdentityTransforms() { + sql( + "CREATE TABLE %s (id INT, salary INT, dep STRING, sub_dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)", + tableName); + configurePlanningMode(planningMode); + + sql("INSERT INTO %s VALUES (1, 100, 'd1', 'sd1')", tableName); + + // the filter can be pushed completely because all specs include identity(dep) + checkOnlyIcebergFilters( + "dep = 'd1'" /* query predicate */, + "dep IS NOT NULL, dep = 'd1'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1", "sd1"))); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec().addField("sub_dep").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO %s VALUES (2, 200, 'd2', 'sd2')", tableName); + + // the filter can be pushed completely because all specs include identity(dep) + checkOnlyIcebergFilters( + "dep = 'd1'" /* query predicate */, + "dep IS NOT NULL, dep = 'd1'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1", "sd1"))); + + table.updateSpec().removeField("sub_dep").removeField("dep").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO %s VALUES (3, 300, 'd3', 'sd3')", tableName); + + // the filter can't be pushed completely because not all specs include identity(dep) + checkFilters( + "dep = 'd1'" /* query predicate */, + "isnotnull(dep) AND (dep = d1)" /* Spark post scan filter */, + "dep IS NOT NULL, dep = 'd1'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1", "sd1"))); + } + + @TestTemplate + public void testFilterPushdownWithSpecEvolutionAndTruncateTransform() { + sql( + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (truncate(2, dep))", + tableName); + configurePlanningMode(planningMode); + + sql("INSERT INTO %s VALUES (1, 100, 'd1')", tableName); + + // the filter can be pushed completely because the current spec supports it + checkOnlyIcebergFilters( + "dep LIKE 'd1%'" /* query predicate */, + "dep IS NOT NULL, dep LIKE 'd1%'" /* Iceberg 
scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + Table table = validationCatalog.loadTable(tableIdent); + table + .updateSpec() + .removeField(Expressions.truncate("dep", 2)) + .addField(Expressions.truncate("dep", 1)) + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO %s VALUES (2, 200, 'd2')", tableName); + + // the filter can be pushed completely because both specs support it + checkOnlyIcebergFilters( + "dep LIKE 'd%'" /* query predicate */, + "dep IS NOT NULL, dep LIKE 'd%'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"), row(2, 200, "d2"))); + + // the filter can't be pushed completely because the second spec is truncate(dep, 1) and + // the predicate literal is d1, which is two chars + checkFilters( + "dep LIKE 'd1%' AND id = 1" /* query predicate */, + "(isnotnull(id) AND StartsWith(dep, d1)) AND (id = 1)" /* Spark post scan filter */, + "dep IS NOT NULL, id IS NOT NULL, dep LIKE 'd1%', id = 1" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + } + + @TestTemplate + public void testFilterPushdownWithSpecEvolutionAndTimeTransforms() { + sql( + "CREATE TABLE %s (id INT, price INT, t TIMESTAMP)" + + "USING iceberg " + + "PARTITIONED BY (hours(t))", + tableName); + configurePlanningMode(planningMode); + + withDefaultTimeZone( + "UTC", + () -> { + sql("INSERT INTO %s VALUES (1, 100, TIMESTAMP '2021-06-30T01:00:00.000Z')", tableName); + + // the filter can be pushed completely because the current spec supports it + checkOnlyIcebergFilters( + "t < TIMESTAMP '2021-07-01T00:00:00.000Z'" /* query predicate */, + "t IS NOT NULL, t < 1625097600000000" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, timestamp("2021-06-30T01:00:00.000Z")))); + + Table table = validationCatalog.loadTable(tableIdent); + table + .updateSpec() + .removeField(Expressions.hour("t")) + .addField(Expressions.month("t")) + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO %s VALUES (2, 200, TIMESTAMP '2021-05-30T01:00:00.000Z')", tableName); + + // the filter can be pushed completely because both specs support it + checkOnlyIcebergFilters( + "t < TIMESTAMP '2021-06-01T00:00:00.000Z'" /* query predicate */, + "t IS NOT NULL, t < 1622505600000000" /* Iceberg scan filters */, + ImmutableList.of(row(2, 200, timestamp("2021-05-30T01:00:00.000Z")))); + }); + } + + @TestTemplate + public void testFilterPushdownWithSpecialFloatingPointPartitionValues() { + sql( + "CREATE TABLE %s (id INT, salary DOUBLE)" + "USING iceberg " + "PARTITIONED BY (salary)", + tableName); + configurePlanningMode(planningMode); + + sql("INSERT INTO %s VALUES (1, 100.5)", tableName); + sql("INSERT INTO %s VALUES (2, double('NaN'))", tableName); + sql("INSERT INTO %s VALUES (3, double('infinity'))", tableName); + sql("INSERT INTO %s VALUES (4, double('-infinity'))", tableName); + + checkOnlyIcebergFilters( + "salary = 100.5" /* query predicate */, + "salary IS NOT NULL, salary = 100.5" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100.5))); + + checkOnlyIcebergFilters( + "salary = double('NaN')" /* query predicate */, + "salary IS NOT NULL, is_nan(salary)" /* Iceberg scan filters */, + ImmutableList.of(row(2, Double.NaN))); + + checkOnlyIcebergFilters( + "salary != double('NaN')" /* query predicate */, + "salary IS NOT NULL, NOT (is_nan(salary))" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100.5), row(3, Double.POSITIVE_INFINITY), row(4, Double.NEGATIVE_INFINITY))); + + checkOnlyIcebergFilters( + "salary = double('infinity')" /* query 
predicate */, + "salary IS NOT NULL, salary = Infinity" /* Iceberg scan filters */, + ImmutableList.of(row(3, Double.POSITIVE_INFINITY))); + + checkOnlyIcebergFilters( + "salary = double('-infinity')" /* query predicate */, + "salary IS NOT NULL, salary = -Infinity" /* Iceberg scan filters */, + ImmutableList.of(row(4, Double.NEGATIVE_INFINITY))); + } + + private void checkOnlyIcebergFilters( + String predicate, String icebergFilters, List expectedRows) { + + checkFilters(predicate, null, icebergFilters, expectedRows); + } + + private void checkFilters( + String predicate, String sparkFilter, String icebergFilters, List expectedRows) { + + Action check = + () -> { + assertEquals( + "Rows must match", + expectedRows, + sql("SELECT * FROM %s WHERE %s ORDER BY id", tableName, predicate)); + }; + SparkPlan sparkPlan = executeAndKeepPlan(check); + String planAsString = sparkPlan.toString().replaceAll("#(\\d+L?)", ""); + + if (sparkFilter != null) { + assertThat(planAsString) + .as("Post scan filter should match") + .contains("Filter (" + sparkFilter + ")"); + } else { + assertThat(planAsString).as("Should be no post scan filter").doesNotContain("Filter ("); + } + + assertThat(planAsString) + .as("Pushed filters must match") + .contains("[filters=" + icebergFilters + ","); + } + + private Timestamp timestamp(String timestampAsString) { + return Timestamp.from(Instant.parse(timestampAsString)); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestNamespaceSQL.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestNamespaceSQL.java new file mode 100644 index 000000000000..0ba480692523 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestNamespaceSQL.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.File; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.exceptions.NamespaceNotEmptyException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestNamespaceSQL extends CatalogTestBase { + private static final Namespace NS = Namespace.of("db"); + + @Parameter(index = 3) + private String fullNamespace; + + @Parameter(index = 4) + private boolean isHadoopCatalog; + + @Parameters( + name = + "catalogName = {0}, implementation = {1}, config = {2}, fullNameSpace = {3}, isHadoopCatalog = {4}") + protected static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties(), + SparkCatalogConfig.HIVE.catalogName() + "." + NS.toString(), + false + }, + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + SparkCatalogConfig.HADOOP.catalogName() + "." + NS, + true + }, + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties(), + NS.toString(), + false + } + }; + } + + @AfterEach + public void cleanNamespaces() { + sql("DROP TABLE IF EXISTS %s.table", fullNamespace); + sql("DROP NAMESPACE IF EXISTS %s", fullNamespace); + } + + @TestTemplate + public void testCreateNamespace() { + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should not already exist") + .isFalse(); + + sql("CREATE NAMESPACE %s", fullNamespace); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should exist") + .isTrue(); + } + + @TestTemplate + public void testDefaultNamespace() { + assumeThat(isHadoopCatalog).as("Hadoop has no default namespace configured").isFalse(); + + sql("USE %s", catalogName); + + Object[] current = Iterables.getOnlyElement(sql("SHOW CURRENT NAMESPACE")); + assertThat(current[0]).as("Should use the current catalog").isEqualTo(catalogName); + assertThat(current[1]).as("Should use the configured default namespace").isEqualTo("default"); + } + + @TestTemplate + public void testDropEmptyNamespace() { + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should not already exist") + .isFalse(); + + sql("CREATE NAMESPACE %s", fullNamespace); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should exist") + .isTrue(); + + sql("DROP NAMESPACE %s", fullNamespace); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should have been dropped") + .isFalse(); + } + + @TestTemplate + public void testDropNonEmptyNamespace() { + assumeThat(catalogName).as("Session catalog has flaky 
behavior").isNotEqualTo("spark_catalog"); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should not already exist") + .isFalse(); + + sql("CREATE NAMESPACE %s", fullNamespace); + sql("CREATE TABLE %s.table (id bigint) USING iceberg", fullNamespace); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should exist") + .isTrue(); + assertThat(validationCatalog.tableExists(TableIdentifier.of(NS, "table"))) + .as("Table should exist") + .isTrue(); + + assertThatThrownBy(() -> sql("DROP NAMESPACE %s", fullNamespace)) + .isInstanceOf(NamespaceNotEmptyException.class) + .hasMessageStartingWith("Namespace db is not empty."); + + sql("DROP TABLE %s.table", fullNamespace); + } + + @TestTemplate + public void testListTables() { + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should not already exist") + .isFalse(); + + sql("CREATE NAMESPACE %s", fullNamespace); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should exist") + .isTrue(); + + List rows = sql("SHOW TABLES IN %s", fullNamespace); + assertThat(rows).as("Should not list any tables").hasSize(0); + + sql("CREATE TABLE %s.table (id bigint) USING iceberg", fullNamespace); + + Object[] row = Iterables.getOnlyElement(sql("SHOW TABLES IN %s", fullNamespace)); + assertThat(row[0]).as("Namespace should match").isEqualTo("db"); + assertThat(row[1]).as("Table name should match").isEqualTo("table"); + } + + @TestTemplate + public void testListNamespace() { + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should not already exist") + .isFalse(); + + sql("CREATE NAMESPACE %s", fullNamespace); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should exist") + .isTrue(); + + List namespaces = sql("SHOW NAMESPACES IN %s", catalogName); + + if (isHadoopCatalog) { + assertThat(namespaces).as("Should have 1 namespace").hasSize(1); + Set namespaceNames = + namespaces.stream().map(arr -> arr[0].toString()).collect(Collectors.toSet()); + assertThat(namespaceNames) + .as("Should have only db namespace") + .isEqualTo(ImmutableSet.of("db")); + } else { + assertThat(namespaces).as("Should have 2 namespaces").hasSize(2); + Set namespaceNames = + namespaces.stream().map(arr -> arr[0].toString()).collect(Collectors.toSet()); + assertThat(namespaceNames) + .as("Should have default and db namespaces") + .isEqualTo(ImmutableSet.of("default", "db")); + } + + List nestedNamespaces = sql("SHOW NAMESPACES IN %s", fullNamespace); + + Set nestedNames = + nestedNamespaces.stream().map(arr -> arr[0].toString()).collect(Collectors.toSet()); + assertThat(nestedNames).as("Should not have nested namespaces").isEmpty(); + } + + @TestTemplate + public void testCreateNamespaceWithMetadata() { + assumeThat(isHadoopCatalog).as("HadoopCatalog does not support namespace metadata").isFalse(); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should not already exist") + .isFalse(); + + sql("CREATE NAMESPACE %s WITH PROPERTIES ('prop'='value')", fullNamespace); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should exist") + .isTrue(); + + Map nsMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + + assertThat(nsMetadata).containsEntry("prop", "value"); + } + + @TestTemplate + public void testCreateNamespaceWithComment() { + assumeThat(isHadoopCatalog).as("HadoopCatalog does not support namespace metadata").isFalse(); + + 
assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should not already exist") + .isFalse(); + + sql("CREATE NAMESPACE %s COMMENT 'namespace doc'", fullNamespace); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should exist") + .isTrue(); + + Map nsMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + + assertThat(nsMetadata).containsEntry("comment", "namespace doc"); + } + + @TestTemplate + public void testCreateNamespaceWithLocation() throws Exception { + assumeThat(isHadoopCatalog).as("HadoopCatalog does not support namespace metadata").isFalse(); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should not already exist") + .isFalse(); + + File location = File.createTempFile("junit", null, temp.toFile()); + assertThat(location.delete()).isTrue(); + + sql("CREATE NAMESPACE %s LOCATION '%s'", fullNamespace, location); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should exist") + .isTrue(); + + Map nsMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + + assertThat(nsMetadata).containsEntry("location", "file:" + location.getPath()); + } + + @TestTemplate + public void testSetProperties() { + assumeThat(isHadoopCatalog).as("HadoopCatalog does not support namespace metadata").isFalse(); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should not already exist") + .isFalse(); + + sql("CREATE NAMESPACE %s", fullNamespace); + + assertThat(validationNamespaceCatalog.namespaceExists(NS)) + .as("Namespace should exist") + .isTrue(); + + Map defaultMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + assertThat(defaultMetadata) + .as("Default metadata should not have custom property") + .doesNotContainKey("prop"); + + sql("ALTER NAMESPACE %s SET PROPERTIES ('prop'='value')", fullNamespace); + + Map nsMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + + assertThat(nsMetadata).containsEntry("prop", "value"); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java new file mode 100644 index 000000000000..800d17dd4559 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +public class TestPartitionedWrites extends PartitionedWritesTestBase {} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesAsSelect.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesAsSelect.java new file mode 100644 index 000000000000..373ca9996efd --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesAsSelect.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.stream.IntStream; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.spark.IcebergSpark; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.spark.sql.types.DataTypes; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestPartitionedWritesAsSelect extends TestBaseWithCatalog { + + @Parameter(index = 3) + private String targetTable; + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}, targetTable = {3}") + protected static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + SparkCatalogConfig.HADOOP.catalogName() + ".default.target_table" + }, + }; + } + + @BeforeEach + public void createTables() { + sql( + "CREATE TABLE %s (id bigint, data string, category string, ts timestamp) USING iceberg", + tableName); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", targetTable); + } + + @TestTemplate + public void testInsertAsSelectAppend() { + insertData(3); + List expected = currentData(); + + sql( + "CREATE TABLE %s (id bigint, data string, category string, ts timestamp)" + + "USING iceberg PARTITIONED BY (days(ts), category)", + targetTable); + + sql( + "INSERT INTO %s SELECT id, data, category, ts FROM %s ORDER BY ts,category", + targetTable, tableName); + assertThat(scalarSql("SELECT count(*) FROM %s", targetTable)) + .as("Should have 15 rows after insert") + .isEqualTo(3 * 5L); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", targetTable)); + } + + @TestTemplate + public void testInsertAsSelectWithBucket() { + insertData(3); + List expected = currentData(); + + sql( + "CREATE TABLE %s (id bigint, data string, category string, ts 
timestamp)" + + "USING iceberg PARTITIONED BY (bucket(8, data))", + targetTable); + + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket8", DataTypes.StringType, 8); + sql( + "INSERT INTO %s SELECT id, data, category, ts FROM %s ORDER BY iceberg_bucket8(data)", + targetTable, tableName); + assertThat(scalarSql("SELECT count(*) FROM %s", targetTable)) + .as("Should have 15 rows after insert") + .isEqualTo(3 * 5L); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", targetTable)); + } + + @TestTemplate + public void testInsertAsSelectWithTruncate() { + insertData(3); + List expected = currentData(); + + sql( + "CREATE TABLE %s (id bigint, data string, category string, ts timestamp)" + + "USING iceberg PARTITIONED BY (truncate(data, 4), truncate(id, 4))", + targetTable); + + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_string4", DataTypes.StringType, 4); + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_long4", DataTypes.LongType, 4); + sql( + "INSERT INTO %s SELECT id, data, category, ts FROM %s " + + "ORDER BY iceberg_truncate_string4(data),iceberg_truncate_long4(id)", + targetTable, tableName); + assertThat(scalarSql("SELECT count(*) FROM %s", targetTable)) + .as("Should have 15 rows after insert") + .isEqualTo(3 * 5L); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", targetTable)); + } + + private void insertData(int repeatCounter) { + IntStream.range(0, repeatCounter) + .forEach( + i -> { + sql( + "INSERT INTO %s VALUES (13, '1', 'bgd16', timestamp('2021-11-10 11:20:10'))," + + "(21, '2', 'bgd13', timestamp('2021-11-10 11:20:10')), " + + "(12, '3', 'bgd14', timestamp('2021-11-10 11:20:10'))," + + "(222, '3', 'bgd15', timestamp('2021-11-10 11:20:10'))," + + "(45, '4', 'bgd16', timestamp('2021-11-10 11:20:10'))", + tableName); + }); + } + + private List currentData() { + return rowsToJava(spark.sql("SELECT * FROM " + tableName + " order by id").collectAsList()); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToBranch.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToBranch.java new file mode 100644 index 000000000000..154c6181a594 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToBranch.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import org.apache.iceberg.Table; +import org.junit.jupiter.api.BeforeEach; + +public class TestPartitionedWritesToBranch extends PartitionedWritesTestBase { + + private static final String BRANCH = "test"; + + @BeforeEach + @Override + public void createTables() { + super.createTables(); + Table table = validationCatalog.loadTable(tableIdent); + table.manageSnapshots().createBranch(BRANCH, table.currentSnapshot().snapshotId()).commit(); + sql("REFRESH TABLE " + tableName); + } + + @Override + protected String commitTarget() { + return String.format("%s.branch_%s", tableName, BRANCH); + } + + @Override + protected String selectTarget() { + return String.format("%s VERSION AS OF '%s'", tableName, BRANCH); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToWapBranch.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToWapBranch.java new file mode 100644 index 000000000000..45c0eb763653 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToWapBranch.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
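// Illustrative sketch (not part of the diff): the branch write/read pattern that
// TestPartitionedWritesToBranch drives through the test harness. Assumes a SparkSession `spark`
// with an Iceberg catalog named "local" and an already-loaded org.apache.iceberg.Table handle
// `table` for local.db.tbl; table, branch, and row values are example placeholders.
table.manageSnapshots().createBranch("audit", table.currentSnapshot().snapshotId()).commit();
spark.sql("REFRESH TABLE local.db.tbl"); // pick up the new ref, as the test does

// writes address the branch through the branch_<name> suffix ...
spark.sql("INSERT INTO local.db.tbl.branch_audit VALUES (100, 'x')");

// ... and reads can use either the suffix or the branch name with VERSION AS OF
spark.sql("SELECT * FROM local.db.tbl.branch_audit").show();
spark.sql("SELECT * FROM local.db.tbl VERSION AS OF 'audit'").show();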
+ */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import java.util.UUID; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestPartitionedWritesToWapBranch extends PartitionedWritesTestBase { + + private static final String BRANCH = "test"; + + @BeforeEach + @Override + public void createTables() { + spark.conf().set(SparkSQLProperties.WAP_BRANCH, BRANCH); + sql( + "CREATE TABLE %s (id bigint, data string) USING iceberg PARTITIONED BY (truncate(id, 3)) OPTIONS (%s = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + } + + @AfterEach + @Override + public void removeTables() { + super.removeTables(); + spark.conf().unset(SparkSQLProperties.WAP_BRANCH); + spark.conf().unset(SparkSQLProperties.WAP_ID); + } + + @Override + protected String commitTarget() { + return tableName; + } + + @Override + protected String selectTarget() { + return String.format("%s VERSION AS OF '%s'", tableName, BRANCH); + } + + @TestTemplate + public void testBranchAndWapBranchCannotBothBeSetForWrite() { + Table table = validationCatalog.loadTable(tableIdent); + table.manageSnapshots().createBranch("test2", table.refs().get(BRANCH).snapshotId()).commit(); + sql("REFRESH TABLE " + tableName); + assertThatThrownBy(() -> sql("INSERT INTO %s.branch_test2 VALUES (4, 'd')", tableName)) + .isInstanceOf(ValidationException.class) + .hasMessage( + "Cannot write to both branch and WAP branch, but got branch [test2] and WAP branch [%s]", + BRANCH); + } + + @TestTemplate + public void testWapIdAndWapBranchCannotBothBeSetForWrite() { + String wapId = UUID.randomUUID().toString(); + spark.conf().set(SparkSQLProperties.WAP_ID, wapId); + assertThatThrownBy(() -> sql("INSERT INTO %s VALUES (4, 'd')", tableName)) + .isInstanceOf(ValidationException.class) + .hasMessage( + "Cannot set both WAP ID and branch, but got ID [%s] and branch [%s]", wapId, BRANCH); + } + + @Override + protected void assertPartitionMetadata( + String tableName, List expected, String... selectPartitionColumns) { + // Cannot read from the .partitions table newly written data into the WAP branch. See + // https://github.com/apache/iceberg/issues/7297 for more details. + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestRefreshTable.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestRefreshTable.java new file mode 100644 index 000000000000..8a9ae0f6030a --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestRefreshTable.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
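// Illustrative sketch (not part of the diff): the write-audit-publish (WAP) flow behind
// TestPartitionedWritesToWapBranch. Assumes a SparkSession `spark` with an Iceberg catalog named
// "local"; SparkSQLProperties and TableProperties are the same classes the test imports, and the
// table/branch names are example values.
spark.sql(
    "CREATE TABLE local.db.events (id BIGINT, data STRING) USING iceberg TBLPROPERTIES ('"
        + TableProperties.WRITE_AUDIT_PUBLISH_ENABLED
        + "'='true')");

// route subsequent writes to the WAP branch instead of main
spark.conf().set(SparkSQLProperties.WAP_BRANCH, "audit");
spark.sql("INSERT INTO local.db.events VALUES (1, 'a')");

// staged rows are visible through the branch, while main stays untouched until publish
spark.sql("SELECT * FROM local.db.events VERSION AS OF 'audit'").show();

// combining this with a branch_<name> write target or a WAP id is rejected, which is what the
// two negative tests above verify; unset the conf once staging is done
spark.conf().unset(SparkSQLProperties.WAP_BRANCH);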
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.util.List; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestRefreshTable extends CatalogTestBase { + + @BeforeEach + public void createTables() { + sql("CREATE TABLE %s (key int, value int) USING iceberg", tableName); + sql("INSERT INTO %s VALUES (1,1)", tableName); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testRefreshCommand() { + // We are not allowed to change the session catalog after it has been initialized, so build a + // new one + if (catalogName.equals(SparkCatalogConfig.SPARK.catalogName()) + || catalogName.equals(SparkCatalogConfig.HADOOP.catalogName())) { + spark.conf().set("spark.sql.catalog." + catalogName + ".cache-enabled", true); + spark = spark.cloneSession(); + } + + List originalExpected = ImmutableList.of(row(1, 1)); + List originalActual = sql("SELECT * FROM %s", tableName); + assertEquals("Table should start as expected", originalExpected, originalActual); + + // Modify table outside of spark, it should be cached so Spark should see the same value after + // mutation + Table table = validationCatalog.loadTable(tableIdent); + DataFile file = table.currentSnapshot().addedDataFiles(table.io()).iterator().next(); + table.newDelete().deleteFile(file).commit(); + + List cachedActual = sql("SELECT * FROM %s", tableName); + assertEquals("Cached table should be unchanged", originalExpected, cachedActual); + + // Refresh the Spark catalog, should be empty + sql("REFRESH TABLE %s", tableName); + List refreshedExpected = ImmutableList.of(); + List refreshedActual = sql("SELECT * FROM %s", tableName); + assertEquals("Refreshed table should be empty", refreshedExpected, refreshedActual); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java new file mode 100644 index 000000000000..3ecfc60b49b4 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java @@ -0,0 +1,600 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
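// Illustrative sketch (not part of the diff): the caching behaviour TestRefreshTable checks.
// Assumes a SparkSession `spark` whose Iceberg catalog "local" was created with
// spark.sql.catalog.local.cache-enabled=true (the test sets the same key and clones the session,
// since catalog settings cannot change after initialization); the table name is an example.
spark.sql("SELECT * FROM local.db.tbl").show(); // first read populates the cached table

// a change committed outside Spark (e.g. through the Iceberg Table API) is not visible yet ...
spark.sql("SELECT * FROM local.db.tbl").show(); // still serves the cached metadata

// ... until the cached entry is invalidated explicitly
spark.sql("REFRESH TABLE local.db.tbl");
spark.sql("SELECT * FROM local.db.tbl").show(); // now reflects the external commit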
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.text.SimpleDateFormat; +import java.util.Arrays; +import java.util.Date; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Table; +import org.apache.iceberg.events.Listeners; +import org.apache.iceberg.events.ScanEvent; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSelect extends CatalogTestBase { + private int scanEventCount = 0; + private ScanEvent lastScanEvent = null; + + @Parameter(index = 3) + private String binaryTableName; + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}, binaryTableName = {3}") + protected static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties(), + SparkCatalogConfig.HIVE.catalogName() + ".default.binary_table" + }, + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + SparkCatalogConfig.HADOOP.catalogName() + ".default.binary_table" + }, + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties(), + "default.binary_table" + } + }; + } + + @BeforeEach + public void createTables() { + // register a scan event listener to validate pushdown + Listeners.register( + event -> { + scanEventCount += 1; + lastScanEvent = event; + }, + ScanEvent.class); + + sql("CREATE TABLE %s (id bigint, data string, float float) USING iceberg", tableName); + sql("INSERT INTO %s VALUES (1, 'a', 1.0), (2, 'b', 2.0), (3, 'c', float('NaN'))", tableName); + + this.scanEventCount = 0; + this.lastScanEvent = null; + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", binaryTableName); + } + + @TestTemplate + public void testSelect() { + List expected = + ImmutableList.of(row(1L, "a", 1.0F), row(2L, "b", 2.0F), row(3L, "c", Float.NaN)); + + assertEquals("Should return all expected rows", expected, sql("SELECT * FROM %s", tableName)); + } + + @TestTemplate + public void testSelectRewrite() { + List expected = ImmutableList.of(row(3L, "c", Float.NaN)); + + assertEquals( + "Should return all 
expected rows", + expected, + sql("SELECT * FROM %s where float = float('NaN')", tableName)); + + assertThat(scanEventCount).as("Should create only one scan").isEqualTo(1); + assertThat(Spark3Util.describe(lastScanEvent.filter())) + .as("Should push down expected filter") + .isEqualTo("(float IS NOT NULL AND is_nan(float))"); + } + + @TestTemplate + public void testProjection() { + List expected = ImmutableList.of(row(1L), row(2L), row(3L)); + + assertEquals("Should return all expected rows", expected, sql("SELECT id FROM %s", tableName)); + + assertThat(scanEventCount).as("Should create only one scan").isEqualTo(1); + assertThat(lastScanEvent.filter()) + .as("Should not push down a filter") + .isEqualTo(Expressions.alwaysTrue()); + assertThat(lastScanEvent.projection().asStruct()) + .as("Should project only the id column") + .isEqualTo(validationCatalog.loadTable(tableIdent).schema().select("id").asStruct()); + } + + @TestTemplate + public void testExpressionPushdown() { + List expected = ImmutableList.of(row("b")); + + assertEquals( + "Should return all expected rows", + expected, + sql("SELECT data FROM %s WHERE id = 2", tableName)); + + assertThat(scanEventCount).as("Should create only one scan").isEqualTo(1); + assertThat(Spark3Util.describe(lastScanEvent.filter())) + .as("Should push down expected filter") + .isEqualTo("(id IS NOT NULL AND id = 2)"); + assertThat(lastScanEvent.projection().asStruct()) + .as("Should project only id and data columns") + .isEqualTo( + validationCatalog.loadTable(tableIdent).schema().select("id", "data").asStruct()); + } + + @TestTemplate + public void testMetadataTables() { + assumeThat(catalogName) + .as("Spark session catalog does not support metadata tables") + .isNotEqualTo("spark_catalog"); + + assertEquals( + "Snapshot metadata table", + ImmutableList.of(row(ANY, ANY, null, "append", ANY, ANY)), + sql("SELECT * FROM %s.snapshots", tableName)); + } + + @TestTemplate + public void testSnapshotInTableName() { + assumeThat(catalogName) + .as("Spark session catalog does not support extended table names") + .isNotEqualTo("spark_catalog"); + + // get the snapshot ID of the last write and get the current row set as expected + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + List expected = sql("SELECT * FROM %s", tableName); + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + String prefix = "snapshot_id_"; + // read the table at the snapshot + List actual = sql("SELECT * FROM %s.%s", tableName, prefix + snapshotId); + assertEquals("Snapshot at specific ID, prefix " + prefix, expected, actual); + + // read the table using DataFrameReader option + Dataset df = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotId) + .load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at specific ID " + snapshotId, expected, fromDF); + } + + @TestTemplate + public void testTimestampInTableName() { + assumeThat(catalogName) + .as("Spark session catalog does not support extended table names") + .isNotEqualTo("spark_catalog"); + + // get a timestamp just after the last write and get the current row set as expected + long snapshotTs = validationCatalog.loadTable(tableIdent).currentSnapshot().timestampMillis(); + long timestamp = waitUntilAfter(snapshotTs + 2); + List expected = sql("SELECT * FROM %s", tableName); + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", 
tableName); + + String prefix = "at_timestamp_"; + // read the table at the snapshot + List actual = sql("SELECT * FROM %s.%s", tableName, prefix + timestamp); + assertEquals("Snapshot at timestamp, prefix " + prefix, expected, actual); + + // read the table using DataFrameReader option + Dataset df = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at timestamp " + timestamp, expected, fromDF); + } + + @TestTemplate + public void testVersionAsOf() { + // get the snapshot ID of the last write and get the current row set as expected + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + List expected = sql("SELECT * FROM %s", tableName); + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + // read the table at the snapshot + List actual1 = sql("SELECT * FROM %s VERSION AS OF %s", tableName, snapshotId); + assertEquals("Snapshot at specific ID", expected, actual1); + + // read the table at the snapshot + // HIVE time travel syntax + List actual2 = + sql("SELECT * FROM %s FOR SYSTEM_VERSION AS OF %s", tableName, snapshotId); + assertEquals("Snapshot at specific ID", expected, actual2); + + // read the table using DataFrameReader option: versionAsOf + Dataset df = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VERSION_AS_OF, snapshotId) + .load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at specific ID " + snapshotId, expected, fromDF); + } + + @TestTemplate + public void testTagReference() { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createTag("test_tag", snapshotId).commit(); + List expected = sql("SELECT * FROM %s", tableName); + + // create a second snapshot, read the table at the tag + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + List actual1 = sql("SELECT * FROM %s VERSION AS OF 'test_tag'", tableName); + assertEquals("Snapshot at specific tag reference name", expected, actual1); + + // read the table at the tag + // HIVE time travel syntax + List actual2 = sql("SELECT * FROM %s FOR SYSTEM_VERSION AS OF 'test_tag'", tableName); + assertEquals("Snapshot at specific tag reference name", expected, actual2); + + // Spark session catalog does not support extended table names + if (!"spark_catalog".equals(catalogName)) { + // read the table using the "tag_" prefix in the table name + List actual3 = sql("SELECT * FROM %s.tag_test_tag", tableName); + assertEquals("Snapshot at specific tag reference name, prefix", expected, actual3); + } + + // read the table using DataFrameReader option: tag + Dataset df = + spark.read().format("iceberg").option(SparkReadOptions.TAG, "test_tag").load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at specific tag reference name", expected, fromDF); + } + + @TestTemplate + public void testUseSnapshotIdForTagReferenceAsOf() { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId1 = table.currentSnapshot().snapshotId(); + + // create a second snapshot, read the table at the snapshot + List actual = sql("SELECT * FROM %s", tableName); + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + table.refresh(); + long snapshotId2 = table.currentSnapshot().snapshotId(); + 
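// Illustrative sketch (not part of the diff): the time-travel entry points covered by the
// TestSelect cases above, written against an assumed SparkSession `spark` and an example table
// name; a real snapshot id would come from the table's snapshot history.
long exampleSnapshotId = 1234567890L; // placeholder value

// SQL time travel: Iceberg syntax plus the SYSTEM_VERSION / SYSTEM_TIME forms
spark.sql("SELECT * FROM local.db.tbl VERSION AS OF " + exampleSnapshotId).show();
spark.sql("SELECT * FROM local.db.tbl FOR SYSTEM_VERSION AS OF " + exampleSnapshotId).show();
spark.sql("SELECT * FROM local.db.tbl TIMESTAMP AS OF '2024-01-01 00:00:00'").show();

// branch and tag names are also valid VERSION AS OF arguments; per the test above, if a tag name
// happens to match a snapshot id, the snapshot id wins
spark.sql("SELECT * FROM local.db.tbl VERSION AS OF 'test_tag'").show();

// DataFrameReader equivalents use the same read options the tests reference
spark.read()
    .format("iceberg")
    .option(SparkReadOptions.SNAPSHOT_ID, exampleSnapshotId)
    .load("local.db.tbl")
    .show();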
table.manageSnapshots().createTag(Long.toString(snapshotId1), snapshotId2).commit(); + + // currently Spark version travel ignores the type of the AS OF + // this means if a tag name matches a snapshot ID, it will always choose snapshotID to travel + // to. + List travelWithStringResult = + sql("SELECT * FROM %s VERSION AS OF '%s'", tableName, snapshotId1); + assertEquals("Snapshot at specific tag reference name", actual, travelWithStringResult); + + List travelWithLongResult = + sql("SELECT * FROM %s VERSION AS OF %s", tableName, snapshotId1); + assertEquals("Snapshot at specific tag reference name", actual, travelWithLongResult); + } + + @TestTemplate + public void testBranchReference() { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createBranch("test_branch", snapshotId).commit(); + List expected = sql("SELECT * FROM %s", tableName); + + // create a second snapshot, read the table at the branch + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + List actual1 = sql("SELECT * FROM %s VERSION AS OF 'test_branch'", tableName); + assertEquals("Snapshot at specific branch reference name", expected, actual1); + + // read the table at the branch + // HIVE time travel syntax + List actual2 = + sql("SELECT * FROM %s FOR SYSTEM_VERSION AS OF 'test_branch'", tableName); + assertEquals("Snapshot at specific branch reference name", expected, actual2); + + // Spark session catalog does not support extended table names + if (!"spark_catalog".equals(catalogName)) { + // read the table using the "branch_" prefix in the table name + List actual3 = sql("SELECT * FROM %s.branch_test_branch", tableName); + assertEquals("Snapshot at specific branch reference name, prefix", expected, actual3); + } + + // read the table using DataFrameReader option: branch + Dataset df = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.BRANCH, "test_branch") + .load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at specific branch reference name", expected, fromDF); + } + + @TestTemplate + public void readAndWriteWithBranchAfterSchemaChange() { + Table table = validationCatalog.loadTable(tableIdent); + String branchName = "test_branch"; + table.manageSnapshots().createBranch(branchName, table.currentSnapshot().snapshotId()).commit(); + + List expected = + Arrays.asList(row(1L, "a", 1.0f), row(2L, "b", 2.0f), row(3L, "c", Float.NaN)); + assertThat(sql("SELECT * FROM %s", tableName)).containsExactlyElementsOf(expected); + + // change schema on the table and add more data + sql("ALTER TABLE %s DROP COLUMN float", tableName); + sql("ALTER TABLE %s ADD COLUMN new_col date", tableName); + sql( + "INSERT INTO %s VALUES (4, 'd', date('2024-04-04')), (5, 'e', date('2024-05-05'))", + tableName); + + // time-travel query using snapshot id should return the snapshot's schema + long branchSnapshotId = table.refs().get(branchName).snapshotId(); + assertThat(sql("SELECT * FROM %s VERSION AS OF %s", tableName, branchSnapshotId)) + .containsExactlyElementsOf(expected); + + // querying the head of the branch should return the table's schema + assertThat(sql("SELECT * FROM %s VERSION AS OF '%s'", tableName, branchName)) + .containsExactly(row(1L, "a", null), row(2L, "b", null), row(3L, "c", null)); + + if (!"spark_catalog".equals(catalogName)) { + // querying the head of the branch using 'branch_' should return the table's schema + assertThat(sql("SELECT * FROM %s.branch_%s", 
tableName, branchName)) + .containsExactly(row(1L, "a", null), row(2L, "b", null), row(3L, "c", null)); + } + + // writing to a branch uses the table's schema + sql( + "INSERT INTO %s.branch_%s VALUES (6L, 'f', cast('2023-06-06' as date)), (7L, 'g', cast('2023-07-07' as date))", + tableName, branchName); + + // querying the head of the branch returns the table's schema + assertThat(sql("SELECT * FROM %s VERSION AS OF '%s'", tableName, branchName)) + .containsExactlyInAnyOrder( + row(1L, "a", null), + row(2L, "b", null), + row(3L, "c", null), + row(6L, "f", java.sql.Date.valueOf("2023-06-06")), + row(7L, "g", java.sql.Date.valueOf("2023-07-07"))); + + // using DataFrameReader with the 'branch' option should return the table's schema + Dataset df = + spark.read().format("iceberg").option(SparkReadOptions.BRANCH, branchName).load(tableName); + assertThat(rowsToJava(df.collectAsList())) + .containsExactlyInAnyOrder( + row(1L, "a", null), + row(2L, "b", null), + row(3L, "c", null), + row(6L, "f", java.sql.Date.valueOf("2023-06-06")), + row(7L, "g", java.sql.Date.valueOf("2023-07-07"))); + } + + @TestTemplate + public void testUnknownReferenceAsOf() { + assertThatThrownBy(() -> sql("SELECT * FROM %s VERSION AS OF 'test_unknown'", tableName)) + .hasMessageContaining("Cannot find matching snapshot ID or reference name for version") + .isInstanceOf(ValidationException.class); + } + + @TestTemplate + public void testTimestampAsOf() { + long snapshotTs = validationCatalog.loadTable(tableIdent).currentSnapshot().timestampMillis(); + long timestamp = waitUntilAfter(snapshotTs + 1000); + waitUntilAfter(timestamp + 1000); + // AS OF expects the timestamp if given in long format will be of seconds precision + long timestampInSeconds = TimeUnit.MILLISECONDS.toSeconds(timestamp); + SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + String formattedDate = sdf.format(new Date(timestamp)); + + List expected = sql("SELECT * FROM %s", tableName); + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + // read the table at the timestamp in long format i.e 1656507980463. + List actualWithLongFormat = + sql("SELECT * FROM %s TIMESTAMP AS OF %s", tableName, timestampInSeconds); + assertEquals("Snapshot at timestamp", expected, actualWithLongFormat); + + // read the table at the timestamp in date format i.e 2022-06-29 18:40:37 + List actualWithDateFormat = + sql("SELECT * FROM %s TIMESTAMP AS OF '%s'", tableName, formattedDate); + assertEquals("Snapshot at timestamp", expected, actualWithDateFormat); + + // HIVE time travel syntax + // read the table at the timestamp in long format i.e 1656507980463. 
+ List actualWithLongFormatInHiveSyntax = + sql("SELECT * FROM %s FOR SYSTEM_TIME AS OF %s", tableName, timestampInSeconds); + assertEquals("Snapshot at specific ID", expected, actualWithLongFormatInHiveSyntax); + + // read the table at the timestamp in date format i.e 2022-06-29 18:40:37 + List actualWithDateFormatInHiveSyntax = + sql("SELECT * FROM %s FOR SYSTEM_TIME AS OF '%s'", tableName, formattedDate); + assertEquals("Snapshot at specific ID", expected, actualWithDateFormatInHiveSyntax); + + // read the table using DataFrameReader option + Dataset df = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.TIMESTAMP_AS_OF, formattedDate) + .load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at timestamp " + timestamp, expected, fromDF); + } + + @TestTemplate + public void testInvalidTimeTravelBasedOnBothAsOfAndTableIdentifier() { + // get the snapshot ID of the last write + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + // get a timestamp just after the last write + long timestamp = + validationCatalog.loadTable(tableIdent).currentSnapshot().timestampMillis() + 2; + + String timestampPrefix = "at_timestamp_"; + String snapshotPrefix = "snapshot_id_"; + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + // using snapshot in table identifier and VERSION AS OF + assertThatThrownBy( + () -> { + sql( + "SELECT * FROM %s.%s VERSION AS OF %s", + tableName, snapshotPrefix + snapshotId, snapshotId); + }) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot do time-travel based on both table identifier and AS OF"); + + // using snapshot in table identifier and TIMESTAMP AS OF + assertThatThrownBy( + () -> { + sql( + "SELECT * FROM %s.%s VERSION AS OF %s", + tableName, timestampPrefix + timestamp, snapshotId); + }) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot do time-travel based on both table identifier and AS OF"); + + // using timestamp in table identifier and VERSION AS OF + assertThatThrownBy( + () -> { + sql( + "SELECT * FROM %s.%s TIMESTAMP AS OF %s", + tableName, snapshotPrefix + snapshotId, timestamp); + }) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot do time-travel based on both table identifier and AS OF"); + + // using timestamp in table identifier and TIMESTAMP AS OF + assertThatThrownBy( + () -> { + sql( + "SELECT * FROM %s.%s TIMESTAMP AS OF %s", + tableName, timestampPrefix + timestamp, timestamp); + }) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot do time-travel based on both table identifier and AS OF"); + } + + @TestTemplate + public void testInvalidTimeTravelAgainstBranchIdentifierWithAsOf() { + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + validationCatalog.loadTable(tableIdent).manageSnapshots().createBranch("b1").commit(); + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + // using branch_b1 in the table identifier and VERSION AS OF + assertThatThrownBy( + () -> sql("SELECT * FROM %s.branch_b1 VERSION AS OF %s", tableName, snapshotId)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot do time-travel based on both table identifier and AS OF"); + + // using branch_b1 in the table identifier and TIMESTAMP AS OF + assertThatThrownBy(() -> sql("SELECT * FROM %s.branch_b1 TIMESTAMP AS OF now()", tableName)) + 
.isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot do time-travel based on both table identifier and AS OF"); + } + + @TestTemplate + public void testSpecifySnapshotAndTimestamp() { + // get the snapshot ID of the last write + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + // get a timestamp just after the last write + long timestamp = + validationCatalog.loadTable(tableIdent).currentSnapshot().timestampMillis() + 2; + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + assertThatThrownBy( + () -> { + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotId) + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .load(tableName) + .collectAsList(); + }) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith( + String.format( + "Can specify only one of snapshot-id (%s), as-of-timestamp (%s)", + snapshotId, timestamp)); + } + + @TestTemplate + public void testBinaryInFilter() { + sql("CREATE TABLE %s (id bigint, binary binary) USING iceberg", binaryTableName); + sql("INSERT INTO %s VALUES (1, X''), (2, X'1111'), (3, X'11')", binaryTableName); + List expected = ImmutableList.of(row(2L, new byte[] {0x11, 0x11})); + + assertEquals( + "Should return all expected rows", + expected, + sql("SELECT id, binary FROM %s where binary > X'11'", binaryTableName)); + } + + @TestTemplate + public void testComplexTypeFilter() { + String complexTypeTableName = tableName("complex_table"); + sql( + "CREATE TABLE %s (id INT, complex STRUCT) USING iceberg", + complexTypeTableName); + sql( + "INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\"))", + complexTypeTableName); + sql( + "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2, \"c2\", \"v2\"))", + complexTypeTableName); + + List result = + sql( + "SELECT id FROM %s WHERE complex = named_struct(\"c1\", 3, \"c2\", \"v1\")", + complexTypeTableName); + + assertEquals("Should return all expected rows", ImmutableList.of(row(1)), result); + sql("DROP TABLE IF EXISTS %s", complexTypeTableName); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java new file mode 100644 index 000000000000..7c1897250b6f --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.relocated.com.google.common.io.BaseEncoding; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.iceberg.spark.functions.BucketFunction; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.types.DataTypes; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkBucketFunction extends TestBaseWithCatalog { + @BeforeEach + public void useCatalog() { + sql("USE %s", catalogName); + } + + @TestTemplate + public void testSpecValues() { + assertThat(new BucketFunction.BucketInt(DataTypes.IntegerType).hash(34)) + .as("Spec example: hash(34) = 2017239379") + .isEqualTo(2017239379); + + assertThat(new BucketFunction.BucketLong(DataTypes.IntegerType).hash(34L)) + .as("Spec example: hash(34L) = 2017239379") + .isEqualTo(2017239379); + + assertThat( + new BucketFunction.BucketDecimal(DataTypes.createDecimalType(9, 2)) + .hash(new BigDecimal("14.20"))) + .as("Spec example: hash(decimal2(14.20)) = -500754589") + .isEqualTo(-500754589); + + Literal date = Literal.of("2017-11-16").to(Types.DateType.get()); + assertThat(new BucketFunction.BucketInt(DataTypes.DateType).hash(date.value())) + .as("Spec example: hash(2017-11-16) = -653330422") + .isEqualTo(-653330422); + + Literal timestampVal = + Literal.of("2017-11-16T22:31:08").to(Types.TimestampType.withoutZone()); + assertThat(new BucketFunction.BucketLong(DataTypes.TimestampType).hash(timestampVal.value())) + .as("Spec example: hash(2017-11-16T22:31:08) = -2047944441") + .isEqualTo(-2047944441); + + Literal timestampntzVal = + Literal.of("2017-11-16T22:31:08").to(Types.TimestampType.withoutZone()); + assertThat( + new BucketFunction.BucketLong(DataTypes.TimestampNTZType).hash(timestampntzVal.value())) + .as("Spec example: hash(2017-11-16T22:31:08) = -2047944441") + .isEqualTo(-2047944441); + + assertThat(new BucketFunction.BucketString().hash("iceberg")) + .as("Spec example: hash(\"iceberg\") = 1210000089") + .isEqualTo(1210000089); + + ByteBuffer bytes = ByteBuffer.wrap(new byte[] {0, 1, 2, 3}); + assertThat(new BucketFunction.BucketBinary().hash(bytes)) + .as("Spec example: hash([00 01 02 03]) = -188683207") + .isEqualTo(-188683207); + } + + @TestTemplate + public void testBucketIntegers() { + assertThat(scalarSql("SELECT system.bucket(10, 8Y)")) + .as("Byte type should bucket similarly to integer") + .isEqualTo(3); + assertThat(scalarSql("SELECT system.bucket(10, 8S)")) + .as("Short type should bucket similarly to integer") + .isEqualTo(3); + // Integers + assertThat(scalarSql("SELECT system.bucket(10, 8)")).isEqualTo(3); + assertThat(scalarSql("SELECT system.bucket(100, 34)")).isEqualTo(79); + assertThat(scalarSql("SELECT system.bucket(1, CAST(null AS INT))")).isNull(); + } + + @TestTemplate + public void testBucketDates() { + assertThat(scalarSql("SELECT system.bucket(10, date('1970-01-09'))")).isEqualTo(3); + assertThat(scalarSql("SELECT system.bucket(100, date('1970-02-04'))")).isEqualTo(79); + assertThat(scalarSql("SELECT system.bucket(1, CAST(null AS DATE))")).isNull(); + } + + @TestTemplate + public void testBucketLong() { + assertThat(scalarSql("SELECT 
system.bucket(100, 34L)")).isEqualTo(79); + assertThat(scalarSql("SELECT system.bucket(100, 0L)")).isEqualTo(76); + assertThat(scalarSql("SELECT system.bucket(100, -34L)")).isEqualTo(97); + assertThat(scalarSql("SELECT system.bucket(2, -1L)")).isEqualTo(0); + assertThat(scalarSql("SELECT system.bucket(2, CAST(null AS LONG))")).isNull(); + } + + @TestTemplate + public void testBucketDecimal() { + assertThat(scalarSql("SELECT system.bucket(64, CAST('12.34' as DECIMAL(9, 2)))")).isEqualTo(56); + assertThat(scalarSql("SELECT system.bucket(18, CAST('12.30' as DECIMAL(9, 2)))")).isEqualTo(13); + assertThat(scalarSql("SELECT system.bucket(16, CAST('12.999' as DECIMAL(9, 3)))")).isEqualTo(2); + assertThat(scalarSql("SELECT system.bucket(32, CAST('0.05' as DECIMAL(5, 2)))")).isEqualTo(21); + assertThat(scalarSql("SELECT system.bucket(128, CAST('0.05' as DECIMAL(9, 2)))")).isEqualTo(85); + assertThat(scalarSql("SELECT system.bucket(18, CAST('0.05' as DECIMAL(9, 2)))")).isEqualTo(3); + + assertThat(scalarSql("SELECT system.bucket(2, CAST(null AS decimal))")) + .as("Null input should return null") + .isNull(); + } + + @TestTemplate + public void testBucketTimestamp() { + assertThat(scalarSql("SELECT system.bucket(100, TIMESTAMP '1997-01-01 00:00:00 UTC+00:00')")) + .isEqualTo(99); + assertThat(scalarSql("SELECT system.bucket(100, TIMESTAMP '1997-01-31 09:26:56 UTC+00:00')")) + .isEqualTo(85); + assertThat(scalarSql("SELECT system.bucket(100, TIMESTAMP '2022-08-08 00:00:00 UTC+00:00')")) + .isEqualTo(62); + assertThat(scalarSql("SELECT system.bucket(2, CAST(null AS timestamp))")).isNull(); + } + + @TestTemplate + public void testBucketString() { + assertThat(scalarSql("SELECT system.bucket(5, 'abcdefg')")).isEqualTo(4); + assertThat(scalarSql("SELECT system.bucket(128, 'abc')")).isEqualTo(122); + assertThat(scalarSql("SELECT system.bucket(64, 'abcde')")).isEqualTo(54); + assertThat(scalarSql("SELECT system.bucket(12, '测试')")).isEqualTo(8); + assertThat(scalarSql("SELECT system.bucket(16, '测试raul试测')")).isEqualTo(1); + assertThat(scalarSql("SELECT system.bucket(16, CAST('测试raul试测' AS varchar(8)))")) + .as("Varchar should work like string") + .isEqualTo(1); + assertThat(scalarSql("SELECT system.bucket(16, CAST('测试raul试测' AS char(8)))")) + .as("Char should work like string") + .isEqualTo(1); + assertThat(scalarSql("SELECT system.bucket(16, '')")) + .as("Should not fail on the empty string") + .isEqualTo(0); + assertThat(scalarSql("SELECT system.bucket(16, CAST(null AS string))")) + .as("Null input should return null as output") + .isNull(); + } + + @TestTemplate + public void testBucketBinary() { + assertThat(scalarSql("SELECT system.bucket(10, X'0102030405060708090a0b0c0d0e0f')")) + .isEqualTo(1); + assertThat(scalarSql("SELECT system.bucket(12, %s)", asBytesLiteral("abcdefg"))).isEqualTo(10); + assertThat(scalarSql("SELECT system.bucket(18, %s)", asBytesLiteral("abc\0\0"))).isEqualTo(13); + assertThat(scalarSql("SELECT system.bucket(48, %s)", asBytesLiteral("abc"))).isEqualTo(42); + assertThat(scalarSql("SELECT system.bucket(16, %s)", asBytesLiteral("测试_"))).isEqualTo(3); + + assertThat(scalarSql("SELECT system.bucket(100, CAST(null AS binary))")) + .as("Null input should return null as output") + .isNull(); + } + + @TestTemplate + public void testNumBucketsAcceptsShortAndByte() { + assertThat(scalarSql("SELECT system.bucket(5S, 1L)")) + .as("Short types should be usable for the number of buckets field") + .isEqualTo(1); + + assertThat(scalarSql("SELECT system.bucket(5Y, 1)")) + .as("Byte types should 
be allowed for the number of buckets field") + .isEqualTo(1); + } + + @TestTemplate + public void testWrongNumberOfArguments() { + assertThatThrownBy(() -> scalarSql("SELECT system.bucket()")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'bucket' cannot process input: (): Wrong number of inputs (expected numBuckets and value)"); + + assertThatThrownBy(() -> scalarSql("SELECT system.bucket(1)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'bucket' cannot process input: (int): Wrong number of inputs (expected numBuckets and value)"); + + assertThatThrownBy(() -> scalarSql("SELECT system.bucket(1, 1L, 1)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'bucket' cannot process input: (int, bigint, int): Wrong number of inputs (expected numBuckets and value)"); + } + + @TestTemplate + public void testInvalidTypesCannotBeUsedForNumberOfBuckets() { + assertThatThrownBy(() -> scalarSql("SELECT system.bucket(CAST('12.34' as DECIMAL(9, 2)), 10)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'bucket' cannot process input: (decimal(9,2), int): Expected number of buckets to be tinyint, shortint or int"); + + assertThatThrownBy(() -> scalarSql("SELECT system.bucket(12L, 10)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'bucket' cannot process input: (bigint, int): Expected number of buckets to be tinyint, shortint or int"); + + assertThatThrownBy(() -> scalarSql("SELECT system.bucket('5', 10)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'bucket' cannot process input: (string, int): Expected number of buckets to be tinyint, shortint or int"); + + assertThatThrownBy(() -> scalarSql("SELECT system.bucket(INTERVAL '100-00' YEAR TO MONTH, 10)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'bucket' cannot process input: (interval year to month, int): Expected number of buckets to be tinyint, shortint or int"); + + assertThatThrownBy( + () -> + scalarSql("SELECT system.bucket(CAST('11 23:4:0' AS INTERVAL DAY TO SECOND), 10)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'bucket' cannot process input: (interval day to second, int): Expected number of buckets to be tinyint, shortint or int"); + } + + @TestTemplate + public void testInvalidTypesForBucketColumn() { + assertThatThrownBy(() -> scalarSql("SELECT system.bucket(10, cast(12.3456 as float))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'bucket' cannot process input: (int, float): Expected column to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary"); + + assertThatThrownBy(() -> scalarSql("SELECT system.bucket(10, cast(12.3456 as double))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'bucket' cannot process input: (int, double): Expected column to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary"); + + assertThatThrownBy(() -> scalarSql("SELECT system.bucket(10, true)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith("Function 'bucket' cannot process input: (int, boolean)"); + + assertThatThrownBy(() -> scalarSql("SELECT system.bucket(10, map(1, 1))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith("Function 'bucket' cannot process input: (int, map)"); + + assertThatThrownBy(() -> scalarSql("SELECT 
system.bucket(10, array(1L))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith("Function 'bucket' cannot process input: (int, array)"); + + assertThatThrownBy(() -> scalarSql("SELECT system.bucket(10, INTERVAL '100-00' YEAR TO MONTH)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'bucket' cannot process input: (int, interval year to month)"); + + assertThatThrownBy( + () -> + scalarSql("SELECT system.bucket(10, CAST('11 23:4:0' AS INTERVAL DAY TO SECOND))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'bucket' cannot process input: (int, interval day to second)"); + } + + @TestTemplate + public void testThatMagicFunctionsAreInvoked() { + // TinyInt + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 6Y)")) + .asString() + .isNotNull() + .contains("staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketInt"); + + // SmallInt + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 6S)")) + .asString() + .isNotNull() + .contains("staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketInt"); + + // Int + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 6)")) + .asString() + .isNotNull() + .contains("staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketInt"); + + // Date + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(100, DATE '2022-08-08')")) + .asString() + .isNotNull() + .contains("staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketInt"); + + // Long + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 6L)")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketLong"); + + // Timestamp + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(100, TIMESTAMP '2022-08-08')")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketLong"); + + // String + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 'abcdefg')")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketString"); + + // Decimal + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, CAST('12.34' AS DECIMAL))")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketDecimal"); + + // Binary + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(4, X'0102030405060708')")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketBinary"); + } + + private String asBytesLiteral(String value) { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + return "X'" + BaseEncoding.base16().encode(bytes) + "'"; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkDaysFunction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkDaysFunction.java new file mode 100644 index 000000000000..36cf196351b8 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkDaysFunction.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.sql.Date; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.spark.sql.AnalysisException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkDaysFunction extends TestBaseWithCatalog { + + @BeforeEach + public void useCatalog() { + sql("USE %s", catalogName); + } + + @TestTemplate + public void testDates() { + assertThat(scalarSql("SELECT system.days(date('2017-12-01'))")) + .as("Expected to produce 2017-12-01") + .isEqualTo(Date.valueOf("2017-12-01")); + assertThat(scalarSql("SELECT system.days(date('1970-01-01'))")) + .as("Expected to produce 1970-01-01") + .isEqualTo(Date.valueOf("1970-01-01")); + assertThat(scalarSql("SELECT system.days(date('1969-12-31'))")) + .as("Expected to produce 1969-12-31") + .isEqualTo(Date.valueOf("1969-12-31")); + assertThat(scalarSql("SELECT system.days(CAST(null AS DATE))")).isNull(); + } + + @TestTemplate + public void testTimestamps() { + assertThat(scalarSql("SELECT system.days(TIMESTAMP '2017-12-01 10:12:55.038194 UTC+00:00')")) + .as("Expected to produce 2017-12-01") + .isEqualTo(Date.valueOf("2017-12-01")); + assertThat(scalarSql("SELECT system.days(TIMESTAMP '1970-01-01 00:00:01.000001 UTC+00:00')")) + .as("Expected to produce 1970-01-01") + .isEqualTo(Date.valueOf("1970-01-01")); + assertThat(scalarSql("SELECT system.days(TIMESTAMP '1969-12-31 23:59:58.999999 UTC+00:00')")) + .as("Expected to produce 1969-12-31") + .isEqualTo(Date.valueOf("1969-12-31")); + assertThat(scalarSql("SELECT system.days(CAST(null AS TIMESTAMP))")).isNull(); + } + + @TestTemplate + public void testTimestampNtz() { + assertThat(scalarSql("SELECT system.days(TIMESTAMP_NTZ '2017-12-01 10:12:55.038194 UTC')")) + .as("Expected to produce 2017-12-01") + .isEqualTo(Date.valueOf("2017-12-01")); + assertThat(scalarSql("SELECT system.days(TIMESTAMP_NTZ '1970-01-01 00:00:01.000001 UTC')")) + .as("Expected to produce 1970-01-01") + .isEqualTo(Date.valueOf("1970-01-01")); + assertThat(scalarSql("SELECT system.days(TIMESTAMP_NTZ '1969-12-31 23:59:58.999999 UTC')")) + .as("Expected to produce 1969-12-31") + .isEqualTo(Date.valueOf("1969-12-31")); + assertThat(scalarSql("SELECT system.days(CAST(null AS TIMESTAMP_NTZ))")).isNull(); + } + + @TestTemplate + public void testWrongNumberOfArguments() { + assertThatThrownBy(() -> scalarSql("SELECT system.days()")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith("Function 'days' cannot process input: (): Wrong number of inputs"); + + assertThatThrownBy( + () -> scalarSql("SELECT system.days(date('1969-12-31'), date('1969-12-31'))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'days' cannot process input: (date, date): Wrong number of inputs"); + } + 
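// Illustrative sketch (not part of the diff): calling the Iceberg transform functions that the
// bucket/days/hours/months tests in this change exercise. Assumes a SparkSession `spark` with an
// Iceberg catalog named "local"; the functions live under the catalog's `system` namespace, so
// the catalog must be current (USE local), as the tests' useCatalog() setup does.
spark.sql("USE local");
spark.sql("SELECT system.bucket(16, 'iceberg')").show();                  // bucket transform
spark.sql("SELECT system.days(TIMESTAMP '2017-12-01 10:12:55')").show();  // day transform
spark.sql("SELECT system.months(DATE '2017-12-01')").show();              // month transform
spark.sql("SELECT system.hours(TIMESTAMP '2017-12-01 10:12:55')").show(); // hour transform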
+ @TestTemplate + public void testInvalidInputTypes() { + assertThatThrownBy(() -> scalarSql("SELECT system.days(1)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'days' cannot process input: (int): Expected value to be date or timestamp"); + + assertThatThrownBy(() -> scalarSql("SELECT system.days(1L)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'days' cannot process input: (bigint): Expected value to be date or timestamp"); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkHoursFunction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkHoursFunction.java new file mode 100644 index 000000000000..17380747b4c0 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkHoursFunction.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.spark.sql.AnalysisException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkHoursFunction extends TestBaseWithCatalog { + + @BeforeEach + public void useCatalog() { + sql("USE %s", catalogName); + } + + @TestTemplate + public void testTimestamps() { + assertThat(scalarSql("SELECT system.hours(TIMESTAMP '2017-12-01 10:12:55.038194 UTC+00:00')")) + .as("Expected to produce 17501 * 24 + 10") + .isEqualTo(420034); + assertThat(scalarSql("SELECT system.hours(TIMESTAMP '1970-01-01 00:00:01.000001 UTC+00:00')")) + .as("Expected to produce 0 * 24 + 0 = 0") + .isEqualTo(0); + assertThat(scalarSql("SELECT system.hours(TIMESTAMP '1969-12-31 23:59:58.999999 UTC+00:00')")) + .as("Expected to produce -1") + .isEqualTo(-1); + assertThat(scalarSql("SELECT system.hours(CAST(null AS TIMESTAMP))")).isNull(); + } + + @TestTemplate + public void testTimestampsNtz() { + assertThat(scalarSql("SELECT system.hours(TIMESTAMP_NTZ '2017-12-01 10:12:55.038194 UTC')")) + .as("Expected to produce 17501 * 24 + 10") + .isEqualTo(420034); + assertThat(scalarSql("SELECT system.hours(TIMESTAMP_NTZ '1970-01-01 00:00:01.000001 UTC')")) + .as("Expected to produce 0 * 24 + 0 = 0") + .isEqualTo(0); + assertThat(scalarSql("SELECT system.hours(TIMESTAMP_NTZ '1969-12-31 23:59:58.999999 UTC')")) + .as("Expected to produce -1") + .isEqualTo(-1); + assertThat(scalarSql("SELECT system.hours(CAST(null AS TIMESTAMP_NTZ))")).isNull(); + } + + @TestTemplate + public void testWrongNumberOfArguments() { + assertThatThrownBy(() -> scalarSql("SELECT system.hours()")) + 
.isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'hours' cannot process input: (): Wrong number of inputs"); + + assertThatThrownBy( + () -> scalarSql("SELECT system.hours(date('1969-12-31'), date('1969-12-31'))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'hours' cannot process input: (date, date): Wrong number of inputs"); + } + + @TestTemplate + public void testInvalidInputTypes() { + assertThatThrownBy(() -> scalarSql("SELECT system.hours(1)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'hours' cannot process input: (int): Expected value to be timestamp"); + + assertThatThrownBy(() -> scalarSql("SELECT system.hours(1L)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'hours' cannot process input: (bigint): Expected value to be timestamp"); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkMonthsFunction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkMonthsFunction.java new file mode 100644 index 000000000000..1a00950124f0 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkMonthsFunction.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.iceberg.spark.functions.MonthsFunction; +import org.apache.spark.sql.AnalysisException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkMonthsFunction extends TestBaseWithCatalog { + + @BeforeEach + public void useCatalog() { + sql("USE %s", catalogName); + } + + @TestTemplate + public void testDates() { + assertThat(scalarSql("SELECT system.months(date('2017-12-01'))")) + .as("Expected to produce 47 * 12 + 11 = 575") + .isEqualTo(575); + assertThat(scalarSql("SELECT system.months(date('1970-01-01'))")) + .as("Expected to produce 0 * 12 + 0 = 0") + .isEqualTo(0); + assertThat(scalarSql("SELECT system.months(date('1969-12-31'))")) + .as("Expected to produce -1") + .isEqualTo(-1); + assertThat(scalarSql("SELECT system.months(CAST(null AS DATE))")).isNull(); + } + + @TestTemplate + public void testTimestamps() { + assertThat(scalarSql("SELECT system.months(TIMESTAMP '2017-12-01 10:12:55.038194 UTC+00:00')")) + .as("Expected to produce 47 * 12 + 11 = 575") + .isEqualTo(575); + assertThat(scalarSql("SELECT system.months(TIMESTAMP '1970-01-01 00:00:01.000001 UTC+00:00')")) + .as("Expected to produce 0 * 12 + 0 = 0") + .isEqualTo(0); + assertThat(scalarSql("SELECT system.months(TIMESTAMP '1969-12-31 23:59:58.999999 UTC+00:00')")) + .as("Expected to produce -1") + .isEqualTo(-1); + assertThat(scalarSql("SELECT system.months(CAST(null AS TIMESTAMP))")).isNull(); + } + + @TestTemplate + public void testTimestampNtz() { + assertThat(scalarSql("SELECT system.months(TIMESTAMP_NTZ '2017-12-01 10:12:55.038194 UTC')")) + .as("Expected to produce 47 * 12 + 11 = 575") + .isEqualTo(575); + assertThat(scalarSql("SELECT system.months(TIMESTAMP_NTZ '1970-01-01 00:00:01.000001 UTC')")) + .as("Expected to produce 0 * 12 + 0 = 0") + .isEqualTo(0); + assertThat(scalarSql("SELECT system.months(TIMESTAMP_NTZ '1969-12-31 23:59:58.999999 UTC')")) + .as("Expected to produce -1") + .isEqualTo(-1); + assertThat(scalarSql("SELECT system.months(CAST(null AS TIMESTAMP_NTZ))")).isNull(); + } + + @TestTemplate + public void testWrongNumberOfArguments() { + assertThatThrownBy(() -> scalarSql("SELECT system.months()")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'months' cannot process input: (): Wrong number of inputs"); + + assertThatThrownBy( + () -> scalarSql("SELECT system.months(date('1969-12-31'), date('1969-12-31'))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'months' cannot process input: (date, date): Wrong number of inputs"); + } + + @TestTemplate + public void testInvalidInputTypes() { + assertThatThrownBy(() -> scalarSql("SELECT system.months(1)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'months' cannot process input: (int): Expected value to be date or timestamp"); + + assertThatThrownBy(() -> scalarSql("SELECT system.months(1L)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'months' cannot process input: (bigint): Expected value to be date or timestamp"); + } + + @TestTemplate + public void testThatMagicFunctionsAreInvoked() { + String dateValue = "date('2017-12-01')"; + String dateTransformClass = MonthsFunction.DateToMonthsFunction.class.getName(); + 
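+    // Magic functions are invoked via staticinvoke over the bound transform class in the
+    // analyzed plan; a non-magic fallback would show up as applyfunctionexpression
+    // instead, which is why the assertions below look for the staticinvoke prefix.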
assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.months(%s)", dateValue)) + .asString() + .isNotNull() + .contains("staticinvoke(class " + dateTransformClass); + + String timestampValue = "TIMESTAMP '2017-12-01 10:12:55.038194 UTC+00:00'"; + String timestampTransformClass = MonthsFunction.TimestampToMonthsFunction.class.getName(); + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.months(%s)", timestampValue)) + .asString() + .isNotNull() + .contains("staticinvoke(class " + timestampTransformClass); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkTruncateFunction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkTruncateFunction.java new file mode 100644 index 000000000000..25f3770d01e4 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkTruncateFunction.java @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import org.apache.iceberg.relocated.com.google.common.io.BaseEncoding; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.spark.sql.AnalysisException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkTruncateFunction extends TestBaseWithCatalog { + + @BeforeEach + public void useCatalog() { + sql("USE %s", catalogName); + } + + @TestTemplate + public void testTruncateTinyInt() { + assertThat(scalarSql("SELECT system.truncate(10, 0Y)")).isEqualTo((byte) 0); + assertThat(scalarSql("SELECT system.truncate(10, 1Y)")).isEqualTo((byte) 0); + assertThat(scalarSql("SELECT system.truncate(10, 5Y)")).isEqualTo((byte) 0); + assertThat(scalarSql("SELECT system.truncate(10, 9Y)")).isEqualTo((byte) 0); + assertThat(scalarSql("SELECT system.truncate(10, 10Y)")).isEqualTo((byte) 10); + assertThat(scalarSql("SELECT system.truncate(10, 11Y)")).isEqualTo((byte) 10); + assertThat(scalarSql("SELECT system.truncate(10, -1Y)")).isEqualTo((byte) -10); + assertThat(scalarSql("SELECT system.truncate(10, -5Y)")).isEqualTo((byte) -10); + assertThat(scalarSql("SELECT system.truncate(10, -10Y)")).isEqualTo((byte) -10); + assertThat(scalarSql("SELECT system.truncate(10, -11Y)")).isEqualTo((byte) -20); + + // Check that different widths can be used + assertThat(scalarSql("SELECT system.truncate(2, -1Y)")).isEqualTo((byte) -2); + + assertThat(scalarSql("SELECT system.truncate(2, CAST(null AS tinyint))")) + .as("Null input should return null") + .isNull(); + } + + @TestTemplate + public void testTruncateSmallInt() { 
+ assertThat(scalarSql("SELECT system.truncate(10, 0S)")).isEqualTo((short) 0); + assertThat(scalarSql("SELECT system.truncate(10, 1S)")).isEqualTo((short) 0); + assertThat(scalarSql("SELECT system.truncate(10, 5S)")).isEqualTo((short) 0); + assertThat(scalarSql("SELECT system.truncate(10, 9S)")).isEqualTo((short) 0); + assertThat(scalarSql("SELECT system.truncate(10, 10S)")).isEqualTo((short) 10); + assertThat(scalarSql("SELECT system.truncate(10, 11S)")).isEqualTo((short) 10); + assertThat(scalarSql("SELECT system.truncate(10, -1S)")).isEqualTo((short) -10); + assertThat(scalarSql("SELECT system.truncate(10, -5S)")).isEqualTo((short) -10); + assertThat(scalarSql("SELECT system.truncate(10, -10S)")).isEqualTo((short) -10); + assertThat(scalarSql("SELECT system.truncate(10, -11S)")).isEqualTo((short) -20); + + // Check that different widths can be used + assertThat(scalarSql("SELECT system.truncate(2, -1S)")).isEqualTo((short) -2); + + assertThat(scalarSql("SELECT system.truncate(2, CAST(null AS smallint))")) + .as("Null input should return null") + .isNull(); + } + + @TestTemplate + public void testTruncateInt() { + assertThat(scalarSql("SELECT system.truncate(10, 0)")).isEqualTo(0); + assertThat(scalarSql("SELECT system.truncate(10, 1)")).isEqualTo(0); + assertThat(scalarSql("SELECT system.truncate(10, 5)")).isEqualTo(0); + assertThat(scalarSql("SELECT system.truncate(10, 9)")).isEqualTo(0); + assertThat(scalarSql("SELECT system.truncate(10, 10)")).isEqualTo(10); + assertThat(scalarSql("SELECT system.truncate(10, 11)")).isEqualTo(10); + assertThat(scalarSql("SELECT system.truncate(10, -1)")).isEqualTo(-10); + assertThat(scalarSql("SELECT system.truncate(10, -5)")).isEqualTo(-10); + assertThat(scalarSql("SELECT system.truncate(10, -10)")).isEqualTo(-10); + assertThat(scalarSql("SELECT system.truncate(10, -11)")).isEqualTo(-20); + + // Check that different widths can be used + assertThat(scalarSql("SELECT system.truncate(2, -1)")).isEqualTo(-2); + assertThat(scalarSql("SELECT system.truncate(300, 1)")).isEqualTo(0); + + assertThat(scalarSql("SELECT system.truncate(2, CAST(null AS int))")) + .as("Null input should return null") + .isNull(); + } + + @TestTemplate + public void testTruncateBigInt() { + assertThat(scalarSql("SELECT system.truncate(10, 0L)")).isEqualTo(0L); + assertThat(scalarSql("SELECT system.truncate(10, 1L)")).isEqualTo(0L); + assertThat(scalarSql("SELECT system.truncate(10, 5L)")).isEqualTo(0L); + assertThat(scalarSql("SELECT system.truncate(10, 9L)")).isEqualTo(0L); + assertThat(scalarSql("SELECT system.truncate(10, 10L)")).isEqualTo(10L); + assertThat(scalarSql("SELECT system.truncate(10, 11L)")).isEqualTo(10L); + assertThat(scalarSql("SELECT system.truncate(10, -1L)")).isEqualTo(-10L); + assertThat(scalarSql("SELECT system.truncate(10, -5L)")).isEqualTo(-10L); + assertThat(scalarSql("SELECT system.truncate(10, -10L)")).isEqualTo(-10L); + assertThat(scalarSql("SELECT system.truncate(10, -11L)")).isEqualTo(-20L); + + // Check that different widths can be used + assertThat(scalarSql("SELECT system.truncate(2, -1L)")).isEqualTo(-2L); + + assertThat(scalarSql("SELECT system.truncate(2, CAST(null AS bigint))")) + .as("Null input should return null") + .isNull(); + } + + @TestTemplate + public void testTruncateDecimal() { + // decimal truncation works by applying the decimal scale to the width: ie 10 scale 2 = 0.10 + assertThat(scalarSql("SELECT system.truncate(10, CAST(%s as DECIMAL(9, 2)))", "12.34")) + .isEqualTo(new BigDecimal("12.30")); + + assertThat(scalarSql("SELECT 
system.truncate(10, CAST(%s as DECIMAL(9, 2)))", "12.30")) + .isEqualTo(new BigDecimal("12.30")); + + assertThat(scalarSql("SELECT system.truncate(10, CAST(%s as DECIMAL(9, 3)))", "12.299")) + .isEqualTo(new BigDecimal("12.290")); + + assertThat(scalarSql("SELECT system.truncate(3, CAST(%s as DECIMAL(5, 2)))", "0.05")) + .isEqualTo(new BigDecimal("0.03")); + + assertThat(scalarSql("SELECT system.truncate(10, CAST(%s as DECIMAL(9, 2)))", "0.05")) + .isEqualTo(new BigDecimal("0.00")); + + assertThat(scalarSql("SELECT system.truncate(10, CAST(%s as DECIMAL(9, 2)))", "-0.05")) + .isEqualTo(new BigDecimal("-0.10")); + + assertThat(scalarSql("SELECT system.truncate(10, 12345.3482)")) + .as("Implicit decimal scale and precision should be allowed") + .isEqualTo(new BigDecimal("12345.3480")); + + BigDecimal truncatedDecimal = + (BigDecimal) scalarSql("SELECT system.truncate(10, CAST(%s as DECIMAL(6, 4)))", "-0.05"); + assertThat(truncatedDecimal.scale()) + .as("Truncating a decimal should return a decimal with the same scale") + .isEqualTo(4); + + assertThat(truncatedDecimal) + .as("Truncating a decimal should return a decimal with the correct scale") + .isEqualTo(BigDecimal.valueOf(-500, 4)); + + assertThat(scalarSql("SELECT system.truncate(2, CAST(null AS decimal))")) + .as("Null input should return null") + .isNull(); + } + + @SuppressWarnings("checkstyle:AvoidEscapedUnicodeCharacters") + @TestTemplate + public void testTruncateString() { + assertThat(scalarSql("SELECT system.truncate(5, 'abcdefg')")) + .as("Should system.truncate strings longer than length") + .isEqualTo("abcde"); + + assertThat(scalarSql("SELECT system.truncate(5, 'abc')")) + .as("Should not pad strings shorter than length") + .isEqualTo("abc"); + + assertThat(scalarSql("SELECT system.truncate(5, 'abcde')")) + .as("Should not alter strings equal to length") + .isEqualTo("abcde"); + + assertThat(scalarSql("SELECT system.truncate(2, 'イロハニホヘト')")) + .as("Strings with multibyte unicode characters should truncate along codepoint boundaries") + .isEqualTo("イロ"); + + assertThat(scalarSql("SELECT system.truncate(3, 'イロハニホヘト')")) + .as("Strings with multibyte unicode characters should truncate along codepoint boundaries") + .isEqualTo("イロハ"); + + assertThat(scalarSql("SELECT system.truncate(7, 'イロハニホヘト')")) + .as( + "Strings with multibyte unicode characters should not alter input with fewer codepoints than width") + .isEqualTo("イロハニホヘト"); + + String stringWithTwoCodePointsEachFourBytes = "\uD800\uDC00\uD800\uDC00"; + assertThat(scalarSql("SELECT system.truncate(1, '%s')", stringWithTwoCodePointsEachFourBytes)) + .as("String truncation on four byte codepoints should work as expected") + .isEqualTo("\uD800\uDC00"); + + assertThat(scalarSql("SELECT system.truncate(1, '测试')")) + .as("Should handle three-byte UTF-8 characters appropriately") + .isEqualTo("测"); + + assertThat(scalarSql("SELECT system.truncate(4, '测试raul试测')")) + .as("Should handle three-byte UTF-8 characters mixed with two byte utf-8 characters") + .isEqualTo("测试ra"); + + assertThat(scalarSql("SELECT system.truncate(10, '')")) + .as("Should not fail on the empty string") + .isEqualTo(""); + + assertThat(scalarSql("SELECT system.truncate(3, CAST(null AS string))")) + .as("Null input should return null as output") + .isNull(); + + assertThat(scalarSql("SELECT system.truncate(4, CAST('测试raul试测' AS varchar(8)))")) + .as("Varchar should work like string") + .isEqualTo("测试ra"); + + assertThat(scalarSql("SELECT system.truncate(4, CAST('测试raul试测' AS char(8)))")) + .as("Char should 
work like string") + .isEqualTo("测试ra"); + } + + @TestTemplate + public void testTruncateBinary() { + assertThat((byte[]) scalarSql("SELECT system.truncate(10, X'0102030405060708090a0b0c0d0e0f')")) + .isEqualTo(new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + + assertThat((byte[]) scalarSql("SELECT system.truncate(3, %s)", asBytesLiteral("abcdefg"))) + .as("Should return the same input when value is equal to truncation width") + .isEqualTo("abc".getBytes(StandardCharsets.UTF_8)); + + assertThat((byte[]) scalarSql("SELECT system.truncate(10, %s)", asBytesLiteral("abc\0\0"))) + .as("Should not truncate, pad, or trim the input when its length is less than the width") + .isEqualTo("abc\0\0".getBytes(StandardCharsets.UTF_8)); + + assertThat((byte[]) scalarSql("SELECT system.truncate(3, %s)", asBytesLiteral("abc"))) + .as("Should not pad the input when its length is equal to the width") + .isEqualTo("abc".getBytes(StandardCharsets.UTF_8)); + + assertThat((byte[]) scalarSql("SELECT system.truncate(6, %s)", asBytesLiteral("测试_"))) + .as("Should handle three-byte UTF-8 characters appropriately") + .isEqualTo("测试".getBytes(StandardCharsets.UTF_8)); + + assertThat(scalarSql("SELECT system.truncate(3, CAST(null AS binary))")) + .as("Null input should return null as output") + .isNull(); + } + + @TestTemplate + public void testTruncateUsingDataframeForWidthWithVaryingWidth() { + // This situation is atypical but allowed. Typically, width is static as data is partitioned on + // one width. + long rumRows = 10L; + long numNonZero = + spark + .range(rumRows) + .toDF("value") + .selectExpr("CAST(value + 1 AS INT) AS width", "value") + .selectExpr("system.truncate(width, value) as truncated_value") + .filter("truncated_value == 0") + .count(); + assertThat(numNonZero) + .as("A truncate function with variable widths should be usable on dataframe columns") + .isEqualTo(rumRows); + } + + @TestTemplate + public void testWidthAcceptsShortAndByte() { + assertThat(scalarSql("SELECT system.truncate(5S, 1L)")) + .as("Short types should be usable for the width field") + .isEqualTo(0L); + + assertThat(scalarSql("SELECT system.truncate(5Y, 1)")) + .as("Byte types should be allowed for the width field") + .isEqualTo(0); + } + + @TestTemplate + public void testWrongNumberOfArguments() { + assertThatThrownBy(() -> scalarSql("SELECT system.truncate()")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'truncate' cannot process input: (): Wrong number of inputs (expected width and value)"); + + assertThatThrownBy(() -> scalarSql("SELECT system.truncate(1)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'truncate' cannot process input: (int): Wrong number of inputs (expected width and value)"); + + assertThatThrownBy(() -> scalarSql("SELECT system.truncate(1, 1L, 1)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'truncate' cannot process input: (int, bigint, int): Wrong number of inputs (expected width and value)"); + } + + @TestTemplate + public void testInvalidTypesCannotBeUsedForWidth() { + assertThatThrownBy( + () -> scalarSql("SELECT system.truncate(CAST('12.34' as DECIMAL(9, 2)), 10)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'truncate' cannot process input: (decimal(9,2), int): Expected truncation width to be tinyint, shortint or int"); + + assertThatThrownBy(() -> scalarSql("SELECT system.truncate('5', 10)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + 
"Function 'truncate' cannot process input: (string, int): Expected truncation width to be tinyint, shortint or int"); + + assertThatThrownBy( + () -> scalarSql("SELECT system.truncate(INTERVAL '100-00' YEAR TO MONTH, 10)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'truncate' cannot process input: (interval year to month, int): Expected truncation width to be tinyint, shortint or int"); + + assertThatThrownBy( + () -> + scalarSql( + "SELECT system.truncate(CAST('11 23:4:0' AS INTERVAL DAY TO SECOND), 10)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'truncate' cannot process input: (interval day to second, int): Expected truncation width to be tinyint, shortint or int"); + } + + @TestTemplate + public void testInvalidTypesForTruncationColumn() { + assertThatThrownBy(() -> scalarSql("SELECT system.truncate(10, cast(12.3456 as float))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'truncate' cannot process input: (int, float): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary"); + + assertThatThrownBy(() -> scalarSql("SELECT system.truncate(10, cast(12.3456 as double))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'truncate' cannot process input: (int, double): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary"); + + assertThatThrownBy(() -> scalarSql("SELECT system.truncate(10, true)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'truncate' cannot process input: (int, boolean): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary"); + + assertThatThrownBy(() -> scalarSql("SELECT system.truncate(10, map(1, 1))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'truncate' cannot process input: (int, map): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary"); + + assertThatThrownBy(() -> scalarSql("SELECT system.truncate(10, array(1L))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'truncate' cannot process input: (int, array): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary"); + + assertThatThrownBy( + () -> scalarSql("SELECT system.truncate(10, INTERVAL '100-00' YEAR TO MONTH)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'truncate' cannot process input: (int, interval year to month): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary"); + + assertThatThrownBy( + () -> + scalarSql( + "SELECT system.truncate(10, CAST('11 23:4:0' AS INTERVAL DAY TO SECOND))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'truncate' cannot process input: (int, interval day to second): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary"); + } + + @TestTemplate + public void testMagicFunctionsResolveForTinyIntAndSmallIntWidths() { + // Magic functions have staticinvoke in the explain output. Nonmagic calls use + // applyfunctionexpression instead. 
+ String tinyIntWidthExplain = + (String) scalarSql("EXPLAIN EXTENDED SELECT system.truncate(1Y, 6)"); + assertThat(tinyIntWidthExplain) + .contains("cast(1 as int)") + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateInt"); + + String smallIntWidth = (String) scalarSql("EXPLAIN EXTENDED SELECT system.truncate(5S, 6L)"); + assertThat(smallIntWidth) + .contains("cast(5 as int)") + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateBigInt"); + } + + @TestTemplate + public void testThatMagicFunctionsAreInvoked() { + // Magic functions have `staticinvoke` in the explain output. + // Non-magic calls have `applyfunctionexpression` instead. + + // TinyInt + assertThat(scalarSql("EXPLAIN EXTENDED select system.truncate(5, 6Y)")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateTinyInt"); + + // SmallInt + assertThat(scalarSql("EXPLAIN EXTENDED select system.truncate(5, 6S)")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateSmallInt"); + + // Int + assertThat(scalarSql("EXPLAIN EXTENDED select system.truncate(5, 6)")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateInt"); + + // Long + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.truncate(5, 6L)")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateBigInt"); + + // String + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.truncate(5, 'abcdefg')")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateString"); + + // Decimal + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.truncate(5, 12.34)")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateDecimal"); + + // Binary + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.truncate(4, X'0102030405060708')")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateBinary"); + } + + private String asBytesLiteral(String value) { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + return "X'" + BaseEncoding.base16().encode(bytes) + "'"; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkYearsFunction.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkYearsFunction.java new file mode 100644 index 000000000000..8cf62b2b48f3 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkYearsFunction.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.iceberg.spark.functions.YearsFunction; +import org.apache.spark.sql.AnalysisException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestSparkYearsFunction extends TestBaseWithCatalog { + + @BeforeEach + public void useCatalog() { + sql("USE %s", catalogName); + } + + @TestTemplate + public void testDates() { + assertThat(scalarSql("SELECT system.years(date('2017-12-01'))")) + .as("Expected to produce 2017 - 1970 = 47") + .isEqualTo(47); + assertThat(scalarSql("SELECT system.years(date('1970-01-01'))")) + .as("Expected to produce 1970 - 1970 = 0") + .isEqualTo(0); + assertThat(scalarSql("SELECT system.years(date('1969-12-31'))")) + .as("Expected to produce 1969 - 1970 = -1") + .isEqualTo(-1); + assertThat(scalarSql("SELECT system.years(CAST(null AS DATE))")).isNull(); + } + + @TestTemplate + public void testTimestamps() { + assertThat(scalarSql("SELECT system.years(TIMESTAMP '2017-12-01 10:12:55.038194 UTC+00:00')")) + .as("Expected to produce 2017 - 1970 = 47") + .isEqualTo(47); + assertThat(scalarSql("SELECT system.years(TIMESTAMP '1970-01-01 00:00:01.000001 UTC+00:00')")) + .as("Expected to produce 1970 - 1970 = 0") + .isEqualTo(0); + assertThat(scalarSql("SELECT system.years(TIMESTAMP '1969-12-31 23:59:58.999999 UTC+00:00')")) + .as("Expected to produce 1969 - 1970 = -1") + .isEqualTo(-1); + assertThat(scalarSql("SELECT system.years(CAST(null AS TIMESTAMP))")).isNull(); + } + + @TestTemplate + public void testTimestampNtz() { + assertThat(scalarSql("SELECT system.years(TIMESTAMP_NTZ '2017-12-01 10:12:55.038194 UTC')")) + .as("Expected to produce 2017 - 1970 = 47") + .isEqualTo(47); + assertThat(scalarSql("SELECT system.years(TIMESTAMP_NTZ '1970-01-01 00:00:01.000001 UTC')")) + .as("Expected to produce 1970 - 1970 = 0") + .isEqualTo(0); + assertThat(scalarSql("SELECT system.years(TIMESTAMP_NTZ '1969-12-31 23:59:58.999999 UTC')")) + .as("Expected to produce 1969 - 1970 = -1") + .isEqualTo(-1); + assertThat(scalarSql("SELECT system.years(CAST(null AS TIMESTAMP_NTZ))")).isNull(); + } + + @TestTemplate + public void testWrongNumberOfArguments() { + assertThatThrownBy(() -> scalarSql("SELECT system.years()")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'years' cannot process input: (): Wrong number of inputs"); + + assertThatThrownBy( + () -> scalarSql("SELECT system.years(date('1969-12-31'), date('1969-12-31'))")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'years' cannot process input: (date, date): Wrong number of inputs"); + } + + @TestTemplate + public void testInvalidInputTypes() { + assertThatThrownBy(() -> scalarSql("SELECT system.years(1)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'years' cannot process input: (int): Expected value to be date or timestamp"); 
+ + assertThatThrownBy(() -> scalarSql("SELECT system.years(1L)")) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + "Function 'years' cannot process input: (bigint): Expected value to be date or timestamp"); + } + + @TestTemplate + public void testThatMagicFunctionsAreInvoked() { + String dateValue = "date('2017-12-01')"; + String dateTransformClass = YearsFunction.DateToYearsFunction.class.getName(); + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.years(%s)", dateValue)) + .asString() + .isNotNull() + .contains("staticinvoke(class " + dateTransformClass); + + String timestampValue = "TIMESTAMP '2017-12-01 10:12:55.038194 UTC+00:00'"; + String timestampTransformClass = YearsFunction.TimestampToYearsFunction.class.getName(); + assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.years(%s)", timestampValue)) + .asString() + .isNotNull() + .contains("staticinvoke(class " + timestampTransformClass); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestStoragePartitionedJoins.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestStoragePartitionedJoins.java new file mode 100644 index 000000000000..6719c45ca961 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestStoragePartitionedJoins.java @@ -0,0 +1,675 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.commons.lang3.StringUtils; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PlanningMode; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestStoragePartitionedJoins extends TestBaseWithCatalog { + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}, planningMode = {3}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + LOCAL + }, + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + DISTRIBUTED + }, + }; + } + + private static final String OTHER_TABLE_NAME = "other_table"; + + // open file cost and split size are set as 16 MB to produce a split per file + private static final Map TABLE_PROPERTIES = + ImmutableMap.of( + TableProperties.SPLIT_SIZE, "16777216", TableProperties.SPLIT_OPEN_FILE_COST, "16777216"); + + // only v2 bucketing and preserve data grouping properties have to be enabled to trigger SPJ + // other properties are only to simplify testing and validation + private static final Map ENABLED_SPJ_SQL_CONF = + ImmutableMap.of( + SQLConf.V2_BUCKETING_ENABLED().key(), + "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED().key(), + "true", + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION().key(), + "false", + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), + "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD().key(), + "-1", + SparkSQLProperties.PRESERVE_DATA_GROUPING, + "true"); + + private static final Map DISABLED_SPJ_SQL_CONF = + ImmutableMap.of( + SQLConf.V2_BUCKETING_ENABLED().key(), + "false", + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION().key(), + "false", + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), + "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD().key(), + "-1", + SparkSQLProperties.PRESERVE_DATA_GROUPING, + "true"); + + @Parameter(index = 
3) + private PlanningMode planningMode; + + @BeforeAll + public static void setupSparkConf() { + spark.conf().set("spark.sql.shuffle.partitions", "4"); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", tableName(OTHER_TABLE_NAME)); + } + + // TODO: add tests for truncate transforms once SPARK-40295 is released + + @TestTemplate + public void testJoinsWithBucketingOnByteColumn() throws NoSuchTableException { + checkJoin("byte_col", "TINYINT", "bucket(4, byte_col)"); + } + + @TestTemplate + public void testJoinsWithBucketingOnShortColumn() throws NoSuchTableException { + checkJoin("short_col", "SMALLINT", "bucket(4, short_col)"); + } + + @TestTemplate + public void testJoinsWithBucketingOnIntColumn() throws NoSuchTableException { + checkJoin("int_col", "INT", "bucket(16, int_col)"); + } + + @TestTemplate + public void testJoinsWithBucketingOnLongColumn() throws NoSuchTableException { + checkJoin("long_col", "BIGINT", "bucket(16, long_col)"); + } + + @TestTemplate + public void testJoinsWithBucketingOnTimestampColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP", "bucket(16, timestamp_col)"); + } + + @TestTemplate + public void testJoinsWithBucketingOnTimestampNtzColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP_NTZ", "bucket(16, timestamp_col)"); + } + + @TestTemplate + public void testJoinsWithBucketingOnDateColumn() throws NoSuchTableException { + checkJoin("date_col", "DATE", "bucket(8, date_col)"); + } + + @TestTemplate + public void testJoinsWithBucketingOnDecimalColumn() throws NoSuchTableException { + checkJoin("decimal_col", "DECIMAL(20, 2)", "bucket(8, decimal_col)"); + } + + @TestTemplate + public void testJoinsWithBucketingOnBinaryColumn() throws NoSuchTableException { + checkJoin("binary_col", "BINARY", "bucket(8, binary_col)"); + } + + @TestTemplate + public void testJoinsWithYearsOnTimestampColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP", "years(timestamp_col)"); + } + + @TestTemplate + public void testJoinsWithYearsOnTimestampNtzColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP_NTZ", "years(timestamp_col)"); + } + + @TestTemplate + public void testJoinsWithYearsOnDateColumn() throws NoSuchTableException { + checkJoin("date_col", "DATE", "years(date_col)"); + } + + @TestTemplate + public void testJoinsWithMonthsOnTimestampColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP", "months(timestamp_col)"); + } + + @TestTemplate + public void testJoinsWithMonthsOnTimestampNtzColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP_NTZ", "months(timestamp_col)"); + } + + @TestTemplate + public void testJoinsWithMonthsOnDateColumn() throws NoSuchTableException { + checkJoin("date_col", "DATE", "months(date_col)"); + } + + @TestTemplate + public void testJoinsWithDaysOnTimestampColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP", "days(timestamp_col)"); + } + + @TestTemplate + public void testJoinsWithDaysOnTimestampNtzColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP_NTZ", "days(timestamp_col)"); + } + + @TestTemplate + public void testJoinsWithDaysOnDateColumn() throws NoSuchTableException { + checkJoin("date_col", "DATE", "days(date_col)"); + } + + @TestTemplate + public void testJoinsWithHoursOnTimestampColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP", 
"hours(timestamp_col)"); + } + + @TestTemplate + public void testJoinsWithHoursOnTimestampNtzColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP_NTZ", "hours(timestamp_col)"); + } + + @TestTemplate + public void testJoinsWithMultipleTransformTypes() throws NoSuchTableException { + String createTableStmt = + "CREATE TABLE %s (" + + " id BIGINT, int_col INT, date_col1 DATE, date_col2 DATE, date_col3 DATE," + + " timestamp_col TIMESTAMP, string_col STRING, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (" + + " years(date_col1), months(date_col2), days(date_col3), hours(timestamp_col), " + + " bucket(8, int_col), dep)" + + "TBLPROPERTIES (%s)"; + + sql(createTableStmt, tableName, tablePropsAsString(TABLE_PROPERTIES)); + sql(createTableStmt, tableName(OTHER_TABLE_NAME), tablePropsAsString(TABLE_PROPERTIES)); + + Table table = validationCatalog.loadTable(tableIdent); + + Dataset dataDF = randomDataDF(table.schema(), 16); + + // write to the first table 1 time to generate 1 file per partition + append(tableName, dataDF); + + // write to the second table 2 times to generate 2 files per partition + append(tableName(OTHER_TABLE_NAME), dataDF); + append(tableName(OTHER_TABLE_NAME), dataDF); + + // Spark SPJ support is limited at the moment and requires all source partitioning columns, + // which were projected in the query, to be part of the join condition + // suppose a table is partitioned by `p1`, `bucket(8, pk)` + // queries covering `p1` and `pk` columns must include equality predicates + // on both `p1` and `pk` to benefit from SPJ + // this is a temporary Spark limitation that will be removed in a future release + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT t1.id " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.dep = t2.dep " + + "ORDER BY t1.id", + tableName, + tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT t1.id, t1.int_col, t1.date_col1 " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.int_col = t2.int_col AND t1.date_col1 = t2.date_col1 " + + "ORDER BY t1.id, t1.int_col, t1.date_col1", + tableName, + tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT t1.id, t1.timestamp_col, t1.string_col " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.timestamp_col = t2.timestamp_col AND t1.string_col = t2.string_col " + + "ORDER BY t1.id, t1.timestamp_col, t1.string_col", + tableName, + tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT t1.id, t1.date_col1, t1.date_col2, t1.date_col3 " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.date_col1 = t2.date_col1 AND t1.date_col2 = t2.date_col2 AND t1.date_col3 = t2.date_col3 " + + "ORDER BY t1.id, t1.date_col1, t1.date_col2, t1.date_col3", + tableName, + tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT t1.id, t1.int_col, t1.timestamp_col, t1.dep " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.int_col = t2.int_col AND t1.timestamp_col = t2.timestamp_col AND 
t1.dep = t2.dep " + + "ORDER BY t1.id, t1.int_col, t1.timestamp_col, t1.dep", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + @TestTemplate + public void testJoinsWithCompatibleSpecEvolution() { + // create a table with an empty spec + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "TBLPROPERTIES (%s)", + tableName, tablePropsAsString(TABLE_PROPERTIES)); + + Table table = validationCatalog.loadTable(tableIdent); + + // evolve the spec in the first table by adding `dep` + table.updateSpec().addField("dep").commit(); + + // insert data into the first table partitioned by `dep` + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName); + + // evolve the spec in the first table by adding `bucket(int_col, 8)` + table.updateSpec().addField(Expressions.bucket("int_col", 8)).commit(); + + // insert data into the first table partitioned by `dep`, `bucket(8, int_col)` + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO %s VALUES (2L, 200, 'hr')", tableName); + + // create another table partitioned by `other_dep` + sql( + "CREATE TABLE %s (other_id BIGINT, other_int_col INT, other_dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (other_dep)" + + "TBLPROPERTIES (%s)", + tableName(OTHER_TABLE_NAME), tablePropsAsString(TABLE_PROPERTIES)); + + // insert data into the second table partitioned by 'other_dep' + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (2L, 200, 'hr')", tableName(OTHER_TABLE_NAME)); + + // SPJ would apply as the grouping keys are compatible + // the first table: `dep` (an intersection of all active partition fields across scanned specs) + // the second table: `other_dep` (the only partition field). 
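+    // Illustrative only (the assertion below already verifies this via the captured plan):
+    // with the SPJ confs enabled, the physical plan should have no Exchange under the
+    // join, which can also be eyeballed with a plain EXPLAIN, e.g.:
+    //
+    //   String plan = spark.sql(String.format(
+    //       "EXPLAIN SELECT * FROM %s t1 JOIN %s t2 ON t1.dep = t2.other_dep",
+    //       tableName, tableName(OTHER_TABLE_NAME))).first().getString(0);
+    //   // expect no "Exchange hashpartitioning" between the two scans
+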
+ + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT * " + + "FROM %s " + + "INNER JOIN %s " + + "ON id = other_id AND int_col = other_int_col AND dep = other_dep " + + "ORDER BY id, int_col, dep", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + @TestTemplate + public void testJoinsWithIncompatibleSpecs() { + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)" + + "TBLPROPERTIES (%s)", + tableName, tablePropsAsString(TABLE_PROPERTIES)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName); + sql("INSERT INTO %s VALUES (2L, 200, 'software')", tableName); + sql("INSERT INTO %s VALUES (3L, 300, 'software')", tableName); + + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (bucket(8, int_col))" + + "TBLPROPERTIES (%s)", + tableName(OTHER_TABLE_NAME), tablePropsAsString(TABLE_PROPERTIES)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (2L, 200, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (3L, 300, 'software')", tableName(OTHER_TABLE_NAME)); + + // queries can't benefit from SPJ as specs are not compatible + // the first table: `dep` + // the second table: `bucket(8, int_col)` + + assertPartitioningAwarePlan( + 3, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles with SPJ */ + "SELECT * " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.int_col = t2.int_col AND t1.dep = t2.dep " + + "ORDER BY t1.id, t1.int_col, t1.dep, t2.id, t2.int_col, t2.dep", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + @TestTemplate + public void testJoinsWithUnpartitionedTables() { + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "TBLPROPERTIES (" + + " 'read.split.target-size' = 16777216," + + " 'read.split.open-file-cost' = 16777216)", + tableName); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName); + sql("INSERT INTO %s VALUES (2L, 200, 'software')", tableName); + sql("INSERT INTO %s VALUES (3L, 300, 'software')", tableName); + + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "TBLPROPERTIES (" + + " 'read.split.target-size' = 16777216," + + " 'read.split.open-file-cost' = 16777216)", + tableName(OTHER_TABLE_NAME)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (2L, 200, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (3L, 300, 'software')", tableName(OTHER_TABLE_NAME)); + + // queries covering unpartitioned tables can't benefit from SPJ but shouldn't fail + + assertPartitioningAwarePlan( + 3, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT * " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.int_col = t2.int_col AND t1.dep = t2.dep " + + "ORDER BY t1.id, t1.int_col, t1.dep, t2.id, t2.int_col, t2.dep", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + @TestTemplate + public void testJoinsWithEmptyTable() { + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)" + + "TBLPROPERTIES (%s)", + tableName, tablePropsAsString(TABLE_PROPERTIES)); + + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)" + 
+ "TBLPROPERTIES (%s)", + tableName(OTHER_TABLE_NAME), tablePropsAsString(TABLE_PROPERTIES)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (2L, 200, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (3L, 300, 'software')", tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 3, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT * " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.int_col = t2.int_col AND t1.dep = t2.dep " + + "ORDER BY t1.id, t1.int_col, t1.dep, t2.id, t2.int_col, t2.dep", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + @TestTemplate + public void testJoinsWithOneSplitTables() { + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)" + + "TBLPROPERTIES (%s)", + tableName, tablePropsAsString(TABLE_PROPERTIES)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName); + + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)" + + "TBLPROPERTIES (%s)", + tableName(OTHER_TABLE_NAME), tablePropsAsString(TABLE_PROPERTIES)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT * " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.int_col = t2.int_col AND t1.dep = t2.dep " + + "ORDER BY t1.id, t1.int_col, t1.dep, t2.id, t2.int_col, t2.dep", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + @TestTemplate + public void testJoinsWithMismatchingPartitionKeys() { + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)" + + "TBLPROPERTIES (%s)", + tableName, tablePropsAsString(TABLE_PROPERTIES)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName); + sql("INSERT INTO %s VALUES (2L, 100, 'hr')", tableName); + + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)" + + "TBLPROPERTIES (%s)", + tableName(OTHER_TABLE_NAME), tablePropsAsString(TABLE_PROPERTIES)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (3L, 300, 'hardware')", tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT * " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.dep = t2.dep " + + "ORDER BY t1.id, t1.int_col, t1.dep, t2.id, t2.int_col, t2.dep", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + @TestTemplate + public void testAggregates() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep, bucket(8, int_col))" + + "TBLPROPERTIES (%s)", + tableName, tablePropsAsString(TABLE_PROPERTIES)); + + // write to the table 3 times to generate 3 files per partition + Table table = validationCatalog.loadTable(tableIdent); + Dataset dataDF = randomDataDF(table.schema(), 100); + append(tableName, dataDF); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT COUNT (DISTINCT id) AS count FROM %s GROUP BY dep, int_col ORDER BY count", + tableName, + 
tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT COUNT (DISTINCT id) AS count FROM %s GROUP BY dep ORDER BY count", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + private void checkJoin(String sourceColumnName, String sourceColumnType, String transform) + throws NoSuchTableException { + + String createTableStmt = + "CREATE TABLE %s (id BIGINT, salary INT, %s %s)" + + "USING iceberg " + + "PARTITIONED BY (%s)" + + "TBLPROPERTIES (%s)"; + + sql( + createTableStmt, + tableName, + sourceColumnName, + sourceColumnType, + transform, + tablePropsAsString(TABLE_PROPERTIES)); + configurePlanningMode(tableName, planningMode); + + sql( + createTableStmt, + tableName(OTHER_TABLE_NAME), + sourceColumnName, + sourceColumnType, + transform, + tablePropsAsString(TABLE_PROPERTIES)); + configurePlanningMode(tableName(OTHER_TABLE_NAME), planningMode); + + Table table = validationCatalog.loadTable(tableIdent); + Dataset dataDF = randomDataDF(table.schema(), 200); + append(tableName, dataDF); + append(tableName(OTHER_TABLE_NAME), dataDF); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT t1.id, t1.salary, t1.%s " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.%s = t2.%s " + + "ORDER BY t1.id, t1.%s", + sourceColumnName, + tableName, + tableName(OTHER_TABLE_NAME), + sourceColumnName, + sourceColumnName, + sourceColumnName); + } + + private void assertPartitioningAwarePlan( + int expectedNumShufflesWithSPJ, + int expectedNumShufflesWithoutSPJ, + String query, + Object... args) { + + AtomicReference> rowsWithSPJ = new AtomicReference<>(); + AtomicReference> rowsWithoutSPJ = new AtomicReference<>(); + + withSQLConf( + ENABLED_SPJ_SQL_CONF, + () -> { + String plan = executeAndKeepPlan(query, args).toString(); + int actualNumShuffles = StringUtils.countMatches(plan, "Exchange"); + assertThat(actualNumShuffles) + .as("Number of shuffles with enabled SPJ must match") + .isEqualTo(expectedNumShufflesWithSPJ); + + rowsWithSPJ.set(sql(query, args)); + }); + + withSQLConf( + DISABLED_SPJ_SQL_CONF, + () -> { + String plan = executeAndKeepPlan(query, args).toString(); + int actualNumShuffles = StringUtils.countMatches(plan, "Exchange"); + assertThat(actualNumShuffles) + .as("Number of shuffles with disabled SPJ must match") + .isEqualTo(expectedNumShufflesWithoutSPJ); + + rowsWithoutSPJ.set(sql(query, args)); + }); + + assertEquals("SPJ should not change query output", rowsWithoutSPJ.get(), rowsWithSPJ.get()); + } + + private Dataset randomDataDF(Schema schema, int numRows) { + Iterable rows = RandomData.generateSpark(schema, numRows, 0); + JavaRDD rowRDD = sparkContext.parallelize(Lists.newArrayList(rows)); + StructType rowSparkType = SparkSchemaUtil.convert(schema); + return spark.internalCreateDataFrame(JavaRDD.toRDD(rowRDD), rowSparkType, false); + } + + private void append(String table, Dataset df) throws NoSuchTableException { + // fanout writes are enabled as write-time clustering is not supported without Spark extensions + df.coalesce(1).writeTo(table).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestTimestampWithoutZone.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestTimestampWithoutZone.java new file mode 100644 index 000000000000..44d895dd44c5 --- /dev/null +++ 
diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestTimestampWithoutZone.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestTimestampWithoutZone.java
new file mode 100644
index 000000000000..44d895dd44c5
--- /dev/null
+++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestTimestampWithoutZone.java
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.sql;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.sql.Timestamp;
+import java.time.LocalDateTime;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.iceberg.Parameters;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.catalog.TableIdentifier;
+import org.apache.iceberg.relocated.com.google.common.base.Joiner;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.spark.CatalogTestBase;
+import org.apache.iceberg.spark.SparkSessionCatalog;
+import org.apache.iceberg.types.Type;
+import org.apache.iceberg.types.Types;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+import org.joda.time.DateTime;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+
+public class TestTimestampWithoutZone extends CatalogTestBase {
+
+  private static final String NEW_TABLE_NAME = "created_table";
+  private static final Schema SCHEMA =
+      new Schema(
+          Types.NestedField.required(1, "id", Types.LongType.get()),
+          Types.NestedField.required(2, "ts", Types.TimestampType.withoutZone()),
+          Types.NestedField.required(3, "tsz", Types.TimestampType.withZone()));
+
+  private final List<Object[]> values =
+      ImmutableList.of(
+          row(1L, toLocalDateTime("2021-01-01T00:00:00.0"), toTimestamp("2021-02-01T00:00:00.0")),
+          row(2L, toLocalDateTime("2021-01-01T00:00:00.0"), toTimestamp("2021-02-01T00:00:00.0")),
+          row(3L, toLocalDateTime("2021-01-01T00:00:00.0"), toTimestamp("2021-02-01T00:00:00.0")));
+
+  @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+  public static Object[][] parameters() {
+    return new Object[][] {
+      {
+        "spark_catalog",
+        SparkSessionCatalog.class.getName(),
+        ImmutableMap.of(
+            "type", "hive",
+            "default-namespace", "default",
+            "parquet-enabled", "true",
+            "cache-enabled", "false")
+      }
+    };
+  }
+
+  @BeforeEach
+  public void createTables() {
+    validationCatalog.createTable(tableIdent, SCHEMA);
+  }
+
+  @AfterEach
+  public void removeTables() {
+    validationCatalog.dropTable(tableIdent, true);
+    sql("DROP TABLE IF EXISTS %s", NEW_TABLE_NAME);
+  }
+
+  /*
+  Spark does not really care about the timezone, it will just convert it
+
+  spark-sql (default)> CREATE TABLE t1 (tz TIMESTAMP, ntz TIMESTAMP_NTZ);
+  Time taken: 1.925 seconds
+
+  spark-sql (default)> INSERT INTO t1 VALUES(timestamp '2020-01-01T00:00:00+02:00', timestamp_ntz '2020-01-01T00:00:00');
+  Time taken: 1.355 seconds
+  spark-sql (default)> INSERT INTO t1 VALUES(timestamp_ntz '2020-01-01T00:00:00+02:00', timestamp_ntz '2020-01-01T00:00:00');
+  Time taken: 0.129 seconds
+  spark-sql (default)> INSERT INTO t1 VALUES(timestamp_ntz '2020-01-01T00:00:00+02:00', timestamp '2020-01-01T00:00:00');
+  Time taken: 0.125 seconds
+  spark-sql (default)> INSERT INTO t1 VALUES(timestamp '2020-01-01T00:00:00+02:00', timestamp '2020-01-01T00:00:00');
+  Time taken: 0.122 seconds
+
+  spark-sql (default)> select * from t1;
+  2020-01-01 00:00:00   2020-01-01 00:00:00
+  2020-01-01 00:00:00   2020-01-01 00:00:00
+  2019-12-31 23:00:00   2020-01-01 00:00:00
+  2019-12-31 23:00:00   2020-01-01 00:00:00
+  Time taken: 0.32 seconds, Fetched 4 row(s)
+
+  spark-sql (default)> SELECT count(1) FROM t1 JOIN t1 as t2 ON t1.tz = t2.ntz;
+  8
+  */
+
+  @TestTemplate
+  public void testAppendTimestampWithoutZone() {
+    // Both NTZ
+    sql(
+        "INSERT INTO %s VALUES %s",
+        tableName,
+        rowToSqlValues(
+            ImmutableList.of(
+                row(
+                    1L,
+                    toLocalDateTime("2021-01-01T00:00:00.0"),
+                    toLocalDateTime("2021-02-01T00:00:00.0")))));
+  }
+
+  @TestTemplate
+  public void testAppendTimestampWithZone() {
+    // Both TZ
+    sql(
+        "INSERT INTO %s VALUES %s",
+        tableName,
+        rowToSqlValues(
+            ImmutableList.of(
+                row(
+                    1L,
+                    toTimestamp("2021-01-01T00:00:00.0"),
+                    toTimestamp("2021-02-01T00:00:00.0")))));
+  }
+
+  @TestTemplate
+  public void testCreateAsSelectWithTimestampWithoutZone() {
+    sql("INSERT INTO %s VALUES %s", tableName, rowToSqlValues(values));
+
+    sql("CREATE TABLE %s USING iceberg AS SELECT * FROM %s", NEW_TABLE_NAME, tableName);
+
+    assertThat(scalarSql("SELECT count(*) FROM %s", NEW_TABLE_NAME))
+        .as("Should have " + values.size() + " row")
+        .isEqualTo((long) values.size());
+
+    assertEquals(
+        "Row data should match expected",
+        sql("SELECT * FROM %s ORDER BY id", tableName),
+        sql("SELECT * FROM %s ORDER BY id", NEW_TABLE_NAME));
+  }
+
+  @TestTemplate
+  public void testCreateNewTableShouldHaveTimestampWithZoneIcebergType() {
+    sql("INSERT INTO %s VALUES %s", tableName, rowToSqlValues(values));
+
+    sql("CREATE TABLE %s USING iceberg AS SELECT * FROM %s", NEW_TABLE_NAME, tableName);
+
+    assertThat(scalarSql("SELECT count(*) FROM %s", NEW_TABLE_NAME))
+        .as("Should have " + values.size() + " row")
+        .isEqualTo((long) values.size());
+
+    assertEquals(
+        "Data from created table should match data from base table",
+        sql("SELECT * FROM %s ORDER BY id", tableName),
+        sql("SELECT * FROM %s ORDER BY id", NEW_TABLE_NAME));
+
+    Table createdTable = validationCatalog.loadTable(TableIdentifier.of("default", NEW_TABLE_NAME));
+    assertFieldsType(createdTable.schema(), Types.TimestampType.withoutZone(), "ts");
+    assertFieldsType(createdTable.schema(), Types.TimestampType.withZone(), "tsz");
+  }
+
+  @TestTemplate
+  public void testCreateNewTableShouldHaveTimestampWithoutZoneIcebergType() {
+    spark
+        .sessionState()
+        .catalogManager()
+        .currentCatalog()
+        .initialize(catalog.name(), new CaseInsensitiveStringMap(catalogConfig));
+    sql("INSERT INTO %s VALUES %s", tableName, rowToSqlValues(values));
+
+    sql("CREATE TABLE %s USING iceberg AS SELECT * FROM %s", NEW_TABLE_NAME, tableName);
+
+    assertThat(scalarSql("SELECT count(*) FROM %s", NEW_TABLE_NAME))
+        .as("Should have " + values.size() + " row")
+        .isEqualTo((long) values.size());
+
+    assertEquals(
+        "Row data should match expected",
+        sql("SELECT * FROM %s ORDER BY id", tableName),
+        sql("SELECT * FROM %s ORDER BY id", NEW_TABLE_NAME));
+    Table createdTable = validationCatalog.loadTable(TableIdentifier.of("default", NEW_TABLE_NAME));
+    assertFieldsType(createdTable.schema(), Types.TimestampType.withoutZone(), "ts");
+    assertFieldsType(createdTable.schema(), Types.TimestampType.withZone(), "tsz");
+  }
+
+  private Timestamp toTimestamp(String value) {
+    return new Timestamp(DateTime.parse(value).getMillis());
+  }
+
+  private LocalDateTime toLocalDateTime(String value) {
+    return LocalDateTime.parse(value);
+  }
+
+  private String rowToSqlValues(List<Object[]> rows) {
+    List<String> rowValues =
+        rows.stream()
+            .map(
+                row -> {
+                  List<String> columns =
+                      Arrays.stream(row)
+                          .map(
+                              value -> {
+                                if (value instanceof Long) {
+                                  return value.toString();
+                                } else if (value instanceof Timestamp) {
+                                  return String.format("timestamp '%s'", value);
+                                } else if (value instanceof LocalDateTime) {
+                                  return String.format("timestamp_ntz '%s'", value);
+                                }
+                                throw new RuntimeException("Type is not supported");
+                              })
+                          .collect(Collectors.toList());
+                  return "(" + Joiner.on(",").join(columns) + ")";
+                })
+            .collect(Collectors.toList());
+    return Joiner.on(",").join(rowValues);
+  }
+
+  private void assertFieldsType(Schema actual, Type.PrimitiveType expected, String... fields) {
+    actual
+        .select(fields)
+        .asStruct()
+        .fields()
+        .forEach(field -> assertThat(field.type()).isEqualTo(expected));
+  }
+}
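The block comment in the test above records how plain Spark coerces TIMESTAMP and TIMESTAMP_NTZ literals; the assertions then pin down how those two Spark types surface as Iceberg types ("tsz" as timestamp with zone, "ts" as timestamp without zone). A minimal sketch of that mapping, assuming a session with an Iceberg catalog named "demo" (table and column names are placeholders):

import org.apache.spark.sql.SparkSession;

class TimestampTypeMappingSketch {
  static void createTable(SparkSession spark) {
    // Spark TIMESTAMP is expected to map to Iceberg timestamptz (Types.TimestampType.withZone()),
    // Spark TIMESTAMP_NTZ to Iceberg timestamp (Types.TimestampType.withoutZone()).
    spark.sql("CREATE TABLE demo.db.ts_types (ltz TIMESTAMP, ntz TIMESTAMP_NTZ) USING iceberg");
  }
}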
diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWrites.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWrites.java
new file mode 100644
index 000000000000..7d9dfe95efc0
--- /dev/null
+++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWrites.java
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.sql;
+
+public class TestUnpartitionedWrites extends UnpartitionedWritesTestBase {}
diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWritesToBranch.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWritesToBranch.java
new file mode 100644
index 000000000000..3df5e9cdf5da
--- /dev/null
+++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWritesToBranch.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.sql;
+
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import org.apache.iceberg.Table;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+
+public class TestUnpartitionedWritesToBranch extends UnpartitionedWritesTestBase {
+
+  private static final String BRANCH = "test";
+
+  @Override
+  @BeforeEach
+  public void createTables() {
+    super.createTables();
+    Table table = validationCatalog.loadTable(tableIdent);
+    table.manageSnapshots().createBranch(BRANCH, table.currentSnapshot().snapshotId()).commit();
+    sql("REFRESH TABLE " + tableName);
+  }
+
+  @Override
+  protected String commitTarget() {
+    return String.format("%s.branch_%s", tableName, BRANCH);
+  }
+
+  @Override
+  protected String selectTarget() {
+    return String.format("%s VERSION AS OF '%s'", tableName, BRANCH);
+  }
+
+  @TestTemplate
+  public void testInsertIntoNonExistingBranchFails() {
+    assertThatThrownBy(
+            () -> sql("INSERT INTO %s.branch_not_exist VALUES (4, 'd'), (5, 'e')", tableName))
+        .isInstanceOf(ValidationException.class)
+        .hasMessage("Cannot use branch (does not exist): not_exist");
+  }
+}
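TestUnpartitionedWritesToBranch reroutes every write in the base suite to the `<table>.branch_<name>` identifier and reads back through `VERSION AS OF`. A standalone sketch of that convention, assuming a hypothetical table "demo.db.tbl" and creating the branch through the Table API as the test does:

import org.apache.iceberg.Table;
import org.apache.spark.sql.SparkSession;

class BranchWriteSketch {
  static void writeAndReadBranch(SparkSession spark, Table table) {
    // create a branch named "test" at the current snapshot
    table.manageSnapshots().createBranch("test", table.currentSnapshot().snapshotId()).commit();
    // writes addressed to <table>.branch_<name> commit to that branch
    spark.sql("INSERT INTO demo.db.tbl.branch_test VALUES (4, 'd')");
    // reads can target the branch head by name
    spark.sql("SELECT * FROM demo.db.tbl VERSION AS OF 'test'").show();
  }
}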
diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/UnpartitionedWritesTestBase.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/UnpartitionedWritesTestBase.java
new file mode 100644
index 000000000000..ab87b89a3529
--- /dev/null
+++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/UnpartitionedWritesTestBase.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.sql;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.Assumptions.assumeThat;
+
+import java.util.List;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.spark.CatalogTestBase;
+import org.apache.iceberg.spark.source.SimpleRecord;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.functions;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+
+public abstract class UnpartitionedWritesTestBase extends CatalogTestBase {
+
+  @BeforeEach
+  public void createTables() {
+    sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName);
+    sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName);
+  }
+
+  @AfterEach
+  public void removeTables() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  @TestTemplate
+  public void testInsertAppend() {
+    assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget()))
+        .as("Should have 3 rows")
+        .isEqualTo(3L);
+
+    sql("INSERT INTO %s VALUES (4, 'd'), (5, 'e')", commitTarget());
+
+    assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget()))
+        .as("Should have 5 rows")
+        .isEqualTo(5L);
+
+    List<Object[]> expected =
+        ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e"));
+
+    assertEquals(
+        "Row data should match expected",
+        expected,
+        sql("SELECT * FROM %s ORDER BY id", selectTarget()));
+  }
+
+  @TestTemplate
+  public void testInsertOverwrite() {
+    assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget()))
+        .as("Should have 3 rows")
+        .isEqualTo(3L);
+
+    sql("INSERT OVERWRITE %s VALUES (4, 'd'), (5, 'e')", commitTarget());
+
+    assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget()))
+        .as("Should have 2 rows after overwrite")
+        .isEqualTo(2L);
+
+    List<Object[]> expected = ImmutableList.of(row(4L, "d"), row(5L, "e"));
+
+    assertEquals(
+        "Row data should match expected",
+        expected,
+        sql("SELECT * FROM %s ORDER BY id", selectTarget()));
+  }
+
+  @TestTemplate
+  public void testInsertAppendAtSnapshot() {
+    assumeThat(tableName.equals(commitTarget())).isTrue();
+    long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    String prefix = "snapshot_id_";
+
+    assertThatThrownBy(
+            () ->
+                sql("INSERT INTO %s.%s VALUES (4, 'd'), (5, 'e')", tableName, prefix + snapshotId))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessageStartingWith("Cannot write to table at a specific snapshot");
+  }
+
+  @TestTemplate
+  public void testInsertOverwriteAtSnapshot() {
+    assumeThat(tableName.equals(commitTarget())).isTrue();
+    long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    String prefix = "snapshot_id_";
+
+    assertThatThrownBy(
+            () ->
+                sql(
+                    "INSERT OVERWRITE %s.%s VALUES (4, 'd'), (5, 'e')",
+                    tableName, prefix + snapshotId))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessageStartingWith("Cannot write to table at a specific snapshot");
+  }
+
+  @TestTemplate
+  public void testDataFrameV2Append() throws NoSuchTableException {
+    assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget()))
+        .as("Should have 3 rows")
+        .isEqualTo(3L);
+
+    List<SimpleRecord> data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e"));
+    Dataset<Row> ds = spark.createDataFrame(data, SimpleRecord.class);
+
+    ds.writeTo(commitTarget()).append();
+
+    assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget()))
+        .as("Should have 5 rows after insert")
+        .isEqualTo(5L);
+
+    List<Object[]> expected =
+        ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e"));
+
+    assertEquals(
+        "Row data should match expected",
+        expected,
+        sql("SELECT * FROM %s ORDER BY id", selectTarget()));
+  }
+
+  @TestTemplate
+  public void testDataFrameV2DynamicOverwrite() throws NoSuchTableException {
+    assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget()))
+        .as("Should have 3 rows")
+        .isEqualTo(3L);
+
+    List<SimpleRecord> data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e"));
+    Dataset<Row> ds = spark.createDataFrame(data, SimpleRecord.class);
+
+    ds.writeTo(commitTarget()).overwritePartitions();
+
+    assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget()))
+        .as("Should have 2 rows after overwrite")
+        .isEqualTo(2L);
+
+    List<Object[]> expected = ImmutableList.of(row(4L, "d"), row(5L, "e"));
+
+    assertEquals(
+        "Row data should match expected",
+        expected,
+        sql("SELECT * FROM %s ORDER BY id", selectTarget()));
+  }
+
+  @TestTemplate
+  public void testDataFrameV2Overwrite() throws NoSuchTableException {
+    assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget()))
+        .as("Should have 3 rows")
+        .isEqualTo(3L);
+
+    List<SimpleRecord> data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e"));
+    Dataset<Row> ds = spark.createDataFrame(data, SimpleRecord.class);
+
+    ds.writeTo(commitTarget()).overwrite(functions.col("id").$less$eq(3));
+
+    assertThat(scalarSql("SELECT count(*) FROM %s", selectTarget()))
+        .as("Should have 2 rows after overwrite")
+        .isEqualTo(2L);
+
+    List<Object[]> expected = ImmutableList.of(row(4L, "d"), row(5L, "e"));
+
+    assertEquals(
+        "Row data should match expected",
+        expected,
+        sql("SELECT * FROM %s ORDER BY id", selectTarget()));
+  }
+}
diff --git a/spark/v4.0/spark/src/test/resources/decimal_dict_and_plain_encoding.parquet b/spark/v4.0/spark/src/test/resources/decimal_dict_and_plain_encoding.parquet
new file mode 100644
index 000000000000..48b3bd1bf24f
Binary files /dev/null and b/spark/v4.0/spark/src/test/resources/decimal_dict_and_plain_encoding.parquet differ
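UnpartitionedWritesTestBase exercises the three DataFrameWriterV2 paths; on an unpartitioned table both `overwritePartitions()` and a matching `overwrite(condition)` end up replacing the existing rows, which is why those tests expect only the two newly written rows afterwards. A minimal sketch of the calls against a placeholder table name:

import static org.apache.spark.sql.functions.col;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;

class V2WriteSketch {
  // "demo.db.tbl" is a hypothetical target; each call below is an independent write.
  static void write(Dataset<Row> df) throws NoSuchTableException {
    df.writeTo("demo.db.tbl").append();                    // add rows, keep existing data
    df.writeTo("demo.db.tbl").overwritePartitions();       // dynamic overwrite; replaces the whole unpartitioned table
    df.writeTo("demo.db.tbl").overwrite(col("id").leq(3)); // overwrite only rows matching the filter
  }
}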