diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 83a1578640699..c7be95109ec7d 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -402,6 +402,8 @@ jobs: runs-on: ubuntu-latest permissions: packages: write + env: + DOCKER_BUILD_RECORD_UPLOAD: false steps: - name: Login to GitHub Container Registry uses: docker/login-action@v3 diff --git a/LICENSE-binary b/LICENSE-binary index c8bd77e7ae2ec..1fc531d800522 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -215,8 +215,10 @@ com.github.joshelser:dropwizard-metrics-hadoop-metrics2-reporter com.google.code.findbugs:jsr305 com.google.code.gson:gson com.google.crypto.tink:tink +com.google.errorprone:error_prone_annotations com.google.flatbuffers:flatbuffers-java com.google.guava:guava +com.google.j2objc:j2objc-annotations com.jamesmurty.utils:java-xmlbuilder com.ning:compress-lzf com.squareup.okhttp3:logging-interceptor @@ -226,7 +228,7 @@ com.tdunning:json com.twitter:chill-java com.twitter:chill_2.13 com.univocity:univocity-parsers -com.zaxxer.HikariCP +com.zaxxer:HikariCP commons-cli:commons-cli commons-codec:commons-codec commons-collections:commons-collections @@ -273,6 +275,7 @@ io.jsonwebtoken:jjwt-jackson io.netty:netty-all io.netty:netty-buffer io.netty:netty-codec +io.netty:netty-codec-dns io.netty:netty-codec-http io.netty:netty-codec-http2 io.netty:netty-codec-socks @@ -280,6 +283,7 @@ io.netty:netty-common io.netty:netty-handler io.netty:netty-handler-proxy io.netty:netty-resolver +io.netty:netty-resolver-dns io.netty:netty-tcnative-boringssl-static io.netty:netty-tcnative-classes io.netty:netty-transport @@ -328,7 +332,6 @@ org.apache.hive:hive-cli org.apache.hive:hive-common org.apache.hive:hive-exec org.apache.hive:hive-jdbc -org.apache.hive:hive-llap-common org.apache.hive:hive-metastore org.apache.hive:hive-serde org.apache.hive:hive-service-rpc @@ -384,6 +387,8 @@ org.glassfish.jersey.core:jersey-client org.glassfish.jersey.core:jersey-common org.glassfish.jersey.core:jersey-server org.glassfish.jersey.inject:jersey-hk2 +org.javassist:javassist +org.jetbrains:annotations org.json4s:json4s-ast_2.13 org.json4s:json4s-core_2.13 org.json4s:json4s-jackson-core_2.13 diff --git a/NOTICE-binary b/NOTICE-binary index 3f36596b9d6d6..a3f302b1cb04d 100644 --- a/NOTICE-binary +++ b/NOTICE-binary @@ -592,9 +592,6 @@ Copyright 2015 The Apache Software Foundation Apache Extras Companion for log4j 1.2. Copyright 2007 The Apache Software Foundation -Hive Metastore -Copyright 2016 The Apache Software Foundation - Apache Commons Logging Copyright 2003-2013 The Apache Software Foundation @@ -969,12 +966,6 @@ The Derby build relies on a jar file supplied by the JSON Simple project, hosted at https://code.google.com/p/json-simple/. The JSON simple jar file is licensed under the Apache 2.0 License. -Hive CLI -Copyright 2016 The Apache Software Foundation - -Hive JDBC -Copyright 2016 The Apache Software Foundation - Chill is a set of Scala extensions for Kryo. Copyright 2012 Twitter, Inc. 
@@ -1056,9 +1047,6 @@ Copyright 2019 The Apache Software Foundation Hive Query Language Copyright 2019 The Apache Software Foundation -Hive Llap Common -Copyright 2019 The Apache Software Foundation - Hive Metastore Copyright 2019 The Apache Software Foundation @@ -1083,8 +1071,6 @@ Copyright 2019 The Apache Software Foundation Hive Storage API Copyright 2018 The Apache Software Foundation -Hive Vector-Code-Gen Utilities -Copyright 2019 The Apache Software Foundation Apache License diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 691e40a75dcd9..7d6ee17ee8e13 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -6195,6 +6195,12 @@ ], "sqlState" : "42000" }, + "UNSUPPORTED_TIME_PRECISION" : { + "message" : [ + "The seconds precision of the TIME data type is out of the supported range [0, 6]." + ], + "sqlState" : "0A001" + }, "UNSUPPORTED_TYPED_LITERAL" : { "message" : [ "Literals of the type are not supported. Supported types are ." diff --git a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala index 110c5f0934286..1901b366074f0 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala @@ -87,7 +87,7 @@ object MDC { * Wrapper class for log messages that include a logging context. * This is used as the return type of the string interpolator `LogStringContext`. */ -case class MessageWithContext(message: String, context: java.util.HashMap[String, String]) { +case class MessageWithContext(message: String, context: java.util.Map[String, String]) { def +(mdc: MessageWithContext): MessageWithContext = { val resultMap = new java.util.HashMap(context) resultMap.putAll(mdc.context) @@ -105,7 +105,7 @@ class LogEntry(messageWithContext: => MessageWithContext) { def message: String = cachedMessageWithContext.message - def context: java.util.HashMap[String, String] = cachedMessageWithContext.context + def context: java.util.Map[String, String] = cachedMessageWithContext.context } /** @@ -166,7 +166,7 @@ trait Logging { } } - protected def withLogContext(context: java.util.HashMap[String, String])(body: => Unit): Unit = { + protected def withLogContext(context: java.util.Map[String, String])(body: => Unit): Unit = { // put into thread context only when structured logging is enabled val closeableThreadContextOpt = if (Logging.isStructuredLoggingEnabled) { Some(CloseableThreadContext.putAll(context)) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 50bf0bea87f94..a940bd90d718f 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -307,7 +307,7 @@ private[spark] class Executor( "executor-heartbeater", HEARTBEAT_INTERVAL_MS) - // must be initialized before running startDriverHeartbeat() + // must be initialized before running heartbeater.start() private val heartbeatReceiverRef = RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorClassLoader.scala index 48d7f150ad9bd..6f8138da6f4fb 100644 --- 
a/core/src/main/scala/org/apache/spark/executor/ExecutorClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorClassLoader.scala @@ -59,9 +59,6 @@ class ExecutorClassLoader( val parentLoader = new ParentClassLoader(parent) - // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes - private[executor] var httpUrlConnectionTimeoutMillis: Int = -1 - private val fetchFn: (String) => InputStream = uri.getScheme() match { case "spark" => getClassFileInputStreamFromSparkRPC case _ => diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index e30380f41566a..df809f4fad745 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -376,9 +376,8 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeNumberField("Task ID", taskId) g.writeNumberField("Stage ID", stageId) g.writeNumberField("Stage Attempt ID", stageAttemptId) - g.writeArrayFieldStart("Accumulator Updates") - updates.foreach(accumulableInfoToJson(_, g)) - g.writeEndArray() + g.writeFieldName("Accumulator Updates") + accumulablesToJson(updates, g) g.writeEndObject() } g.writeEndArray() @@ -496,7 +495,7 @@ private[spark] object JsonProtocol extends JsonUtils { def accumulablesToJson( accumulables: Iterable[AccumulableInfo], g: JsonGenerator, - includeTaskMetricsAccumulators: Boolean = true): Unit = { + includeTaskMetricsAccumulators: Boolean = true): Unit = { g.writeStartArray() accumulables .filterNot { acc => @@ -714,11 +713,8 @@ private[spark] object JsonProtocol extends JsonUtils { reason.foreach(g.writeStringField("Loss Reason", _)) case taskKilled: TaskKilled => g.writeStringField("Kill Reason", taskKilled.reason) - g.writeArrayFieldStart("Accumulator Updates") - taskKilled.accumUpdates.foreach { info => - accumulableInfoToJson(info, g) - } - g.writeEndArray() + g.writeFieldName("Accumulator Updates") + accumulablesToJson(taskKilled.accumUpdates, g) case _ => // no extra fields to write } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e5e4bcacc70c3..1dcc91af862e9 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -54,6 +54,7 @@ import org.apache.commons.io.IOUtils import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} +import org.apache.hadoop.fs.audit.CommonAuditContext.currentAuditContext import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec} import org.apache.hadoop.ipc.{CallerContext => HadoopCallerContext} import org.apache.hadoop.ipc.CallerContext.{Builder => HadoopCallerContextBuilder} @@ -3171,6 +3172,9 @@ private[util] object CallerContext extends Logging { * specific applications impacting parts of the Hadoop system and potential problems they may be * creating (e.g. overloading NN). As HDFS mentioned in HDFS-9184, for a given HDFS operation, it's * very helpful to track which upper level job issues it. + * The context information is also set in the audit context for cloud storage + * connectors. If supported, this gets marshalled as part of the HTTP Referrer header + * or similar field, and so ends up in the store service logs themselves. 
* * @param from who sets up the caller context (TASK, CLIENT, APPMASTER) * * @@ -3221,11 +3225,15 @@ private[spark] class CallerContext( /** * Set up the caller context [[context]] by invoking Hadoop CallerContext API of - * [[HadoopCallerContext]]. + * [[HadoopCallerContext]], which is included in IPC calls, + * and the Hadoop audit context, which may be included in cloud storage + * requests. */ def setCurrentContext(): Unit = if (CallerContext.callerContextEnabled) { val hdfsContext = new HadoopCallerContextBuilder(context).build() HadoopCallerContext.setCurrent(hdfsContext) + // set the audit context for object stores, with the prefix "spark" + currentAuditContext.put("spark", context) } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 89e3d8371be4c..a9399edeb9ad7 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -1166,7 +1166,9 @@ private[spark] object JsonProtocolSuite extends Assertions { assert(taskId1 === taskId2) assert(stageId1 === stageId2) assert(stageAttemptId1 === stageAttemptId2) - assertSeqEquals[AccumulableInfo](updates1, updates2, (a, b) => a.equals(b)) + val filteredUpdates = updates1 + .filterNot { acc => acc.name.exists(accumulableExcludeList.contains) } + assertSeqEquals[AccumulableInfo](filteredUpdates, updates2, (a, b) => a.equals(b)) }) assertSeqEquals[((Int, Int), ExecutorMetrics)]( e1.executorUpdates.toSeq.sortBy(_._1), @@ -1299,7 +1301,9 @@ private[spark] object JsonProtocolSuite extends Assertions { assert(r1.description === r2.description) assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals) assert(r1.fullStackTrace === r2.fullStackTrace) - assertSeqEquals[AccumulableInfo](r1.accumUpdates, r2.accumUpdates, (a, b) => a.equals(b)) + val filteredUpdates = r1.accumUpdates + .filterNot { acc => acc.name.exists(accumulableExcludeList.contains) } + assertSeqEquals[AccumulableInfo](filteredUpdates, r2.accumUpdates, (a, b) => a.equals(b)) case (TaskResultLost, TaskResultLost) => case (r1: TaskKilled, r2: TaskKilled) => assert(r1.reason == r2.reason) @@ -2774,28 +2778,6 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Count Failed Values": true | }, | { - | "ID": 12, - | "Name": "$UPDATED_BLOCK_STATUSES", - | "Update": [ - | { - | "Block ID": "rdd_0_0", - | "Status": { - | "Storage Level": { - | "Use Disk": true, - | "Use Memory": true, - | "Use Off Heap": false, - | "Deserialized": false, - | "Replication": 2 - | }, - | "Memory Size": 0, - | "Disk Size": 0 - | } - | } - | ], - | "Internal": true, - | "Count Failed Values": true - | }, - | { | "ID": 13, | "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}", | "Update": 0, diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 3312bd3d5743f..077dd489378fd 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -37,6 +37,7 @@ import org.apache.commons.lang3.SystemUtils import org.apache.commons.math3.stat.inference.ChiSquareTest import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.audit.CommonAuditContext.currentAuditContext import org.apache.hadoop.ipc.{CallerContext => HadoopCallerContext} import org.apache.logging.log4j.Level @@ -1003,9 +1004,19 @@ class
UtilsSuite extends SparkFunSuite with ResetSystemProperties { } test("Set Spark CallerContext") { - val context = "test" - new CallerContext(context).setCurrentContext() - assert(s"SPARK_$context" === HadoopCallerContext.getCurrent.toString) + currentAuditContext.reset + new CallerContext("test", + Some("upstream"), + Some("app"), + Some("attempt"), + Some(1), + Some(2), + Some(3), + Some(4), + Some(5)).setCurrentContext() + val expected = s"SPARK_test_app_attempt_JId_1_SId_2_3_TId_4_5_upstream" + assert(expected === HadoopCallerContext.getCurrent.toString) + assert(expected === currentAuditContext.get("spark")) } test("encodeFileNameToURIRawPath") { diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 08608b057b8a1..eff533572a046 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -129,7 +129,7 @@ javax.servlet-api/4.0.1//javax.servlet-api-4.0.1.jar javolution/5.5.1//javolution-5.5.1.jar jaxb-core/4.0.5//jaxb-core-4.0.5.jar jaxb-runtime/4.0.5//jaxb-runtime-4.0.5.jar -jcl-over-slf4j/2.0.16//jcl-over-slf4j-2.0.16.jar +jcl-over-slf4j/2.0.17//jcl-over-slf4j-2.0.17.jar jdo-api/3.0.1//jdo-api-3.0.1.jar jdom2/2.0.6//jdom2-2.0.6.jar jersey-client/3.0.16//jersey-client-3.0.16.jar @@ -157,7 +157,7 @@ json4s-jackson_2.13/4.0.7//json4s-jackson_2.13-4.0.7.jar json4s-scalap_2.13/4.0.7//json4s-scalap_2.13-4.0.7.jar jsr305/3.0.0//jsr305-3.0.0.jar jta/1.1//jta-1.1.jar -jul-to-slf4j/2.0.16//jul-to-slf4j-2.0.16.jar +jul-to-slf4j/2.0.17//jul-to-slf4j-2.0.17.jar kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar kubernetes-client-api/7.1.0//kubernetes-client-api-7.1.0.jar kubernetes-client/7.1.0//kubernetes-client-7.1.0.jar @@ -201,33 +201,33 @@ metrics-jmx/4.2.30//metrics-jmx-4.2.30.jar metrics-json/4.2.30//metrics-json-4.2.30.jar metrics-jvm/4.2.30//metrics-jvm-4.2.30.jar minlog/1.3.0//minlog-1.3.0.jar -netty-all/4.1.118.Final//netty-all-4.1.118.Final.jar -netty-buffer/4.1.118.Final//netty-buffer-4.1.118.Final.jar -netty-codec-dns/4.1.118.Final//netty-codec-dns-4.1.118.Final.jar -netty-codec-http/4.1.118.Final//netty-codec-http-4.1.118.Final.jar -netty-codec-http2/4.1.118.Final//netty-codec-http2-4.1.118.Final.jar -netty-codec-socks/4.1.118.Final//netty-codec-socks-4.1.118.Final.jar -netty-codec/4.1.118.Final//netty-codec-4.1.118.Final.jar -netty-common/4.1.118.Final//netty-common-4.1.118.Final.jar -netty-handler-proxy/4.1.118.Final//netty-handler-proxy-4.1.118.Final.jar -netty-handler/4.1.118.Final//netty-handler-4.1.118.Final.jar -netty-resolver-dns/4.1.118.Final//netty-resolver-dns-4.1.118.Final.jar -netty-resolver/4.1.118.Final//netty-resolver-4.1.118.Final.jar +netty-all/4.1.119.Final//netty-all-4.1.119.Final.jar +netty-buffer/4.1.119.Final//netty-buffer-4.1.119.Final.jar +netty-codec-dns/4.1.119.Final//netty-codec-dns-4.1.119.Final.jar +netty-codec-http/4.1.119.Final//netty-codec-http-4.1.119.Final.jar +netty-codec-http2/4.1.119.Final//netty-codec-http2-4.1.119.Final.jar +netty-codec-socks/4.1.119.Final//netty-codec-socks-4.1.119.Final.jar +netty-codec/4.1.119.Final//netty-codec-4.1.119.Final.jar +netty-common/4.1.119.Final//netty-common-4.1.119.Final.jar +netty-handler-proxy/4.1.119.Final//netty-handler-proxy-4.1.119.Final.jar +netty-handler/4.1.119.Final//netty-handler-4.1.119.Final.jar +netty-resolver-dns/4.1.119.Final//netty-resolver-dns-4.1.119.Final.jar +netty-resolver/4.1.119.Final//netty-resolver-4.1.119.Final.jar 
netty-tcnative-boringssl-static/2.0.70.Final/linux-aarch_64/netty-tcnative-boringssl-static-2.0.70.Final-linux-aarch_64.jar netty-tcnative-boringssl-static/2.0.70.Final/linux-x86_64/netty-tcnative-boringssl-static-2.0.70.Final-linux-x86_64.jar netty-tcnative-boringssl-static/2.0.70.Final/osx-aarch_64/netty-tcnative-boringssl-static-2.0.70.Final-osx-aarch_64.jar netty-tcnative-boringssl-static/2.0.70.Final/osx-x86_64/netty-tcnative-boringssl-static-2.0.70.Final-osx-x86_64.jar netty-tcnative-boringssl-static/2.0.70.Final/windows-x86_64/netty-tcnative-boringssl-static-2.0.70.Final-windows-x86_64.jar netty-tcnative-classes/2.0.70.Final//netty-tcnative-classes-2.0.70.Final.jar -netty-transport-classes-epoll/4.1.118.Final//netty-transport-classes-epoll-4.1.118.Final.jar -netty-transport-classes-kqueue/4.1.118.Final//netty-transport-classes-kqueue-4.1.118.Final.jar -netty-transport-native-epoll/4.1.118.Final/linux-aarch_64/netty-transport-native-epoll-4.1.118.Final-linux-aarch_64.jar -netty-transport-native-epoll/4.1.118.Final/linux-riscv64/netty-transport-native-epoll-4.1.118.Final-linux-riscv64.jar -netty-transport-native-epoll/4.1.118.Final/linux-x86_64/netty-transport-native-epoll-4.1.118.Final-linux-x86_64.jar -netty-transport-native-kqueue/4.1.118.Final/osx-aarch_64/netty-transport-native-kqueue-4.1.118.Final-osx-aarch_64.jar -netty-transport-native-kqueue/4.1.118.Final/osx-x86_64/netty-transport-native-kqueue-4.1.118.Final-osx-x86_64.jar -netty-transport-native-unix-common/4.1.118.Final//netty-transport-native-unix-common-4.1.118.Final.jar -netty-transport/4.1.118.Final//netty-transport-4.1.118.Final.jar +netty-transport-classes-epoll/4.1.119.Final//netty-transport-classes-epoll-4.1.119.Final.jar +netty-transport-classes-kqueue/4.1.119.Final//netty-transport-classes-kqueue-4.1.119.Final.jar +netty-transport-native-epoll/4.1.119.Final/linux-aarch_64/netty-transport-native-epoll-4.1.119.Final-linux-aarch_64.jar +netty-transport-native-epoll/4.1.119.Final/linux-riscv64/netty-transport-native-epoll-4.1.119.Final-linux-riscv64.jar +netty-transport-native-epoll/4.1.119.Final/linux-x86_64/netty-transport-native-epoll-4.1.119.Final-linux-x86_64.jar +netty-transport-native-kqueue/4.1.119.Final/osx-aarch_64/netty-transport-native-kqueue-4.1.119.Final-osx-aarch_64.jar +netty-transport-native-kqueue/4.1.119.Final/osx-x86_64/netty-transport-native-kqueue-4.1.119.Final-osx-x86_64.jar +netty-transport-native-unix-common/4.1.119.Final//netty-transport-native-unix-common-4.1.119.Final.jar +netty-transport/4.1.119.Final//netty-transport-4.1.119.Final.jar objenesis/3.3//objenesis-3.3.jar okhttp/3.12.12//okhttp-3.12.12.jar okio/1.17.6//okio-1.17.6.jar @@ -251,7 +251,7 @@ parquet-jackson/1.15.0//parquet-jackson-1.15.0.jar pickle/1.5//pickle-1.5.jar py4j/0.10.9.9//py4j-0.10.9.9.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar -rocksdbjni/9.10.0//rocksdbjni-9.10.0.jar +rocksdbjni/9.8.4//rocksdbjni-9.8.4.jar scala-collection-compat_2.13/2.7.0//scala-collection-compat_2.13-2.7.0.jar scala-compiler/2.13.16//scala-compiler-2.13.16.jar scala-library/2.13.16//scala-library-2.13.16.jar @@ -259,7 +259,7 @@ scala-parallel-collections_2.13/1.2.0//scala-parallel-collections_2.13-1.2.0.jar scala-parser-combinators_2.13/2.4.0//scala-parser-combinators_2.13-2.4.0.jar scala-reflect/2.13.16//scala-reflect-2.13.16.jar scala-xml_2.13/2.3.0//scala-xml_2.13-2.3.0.jar -slf4j-api/2.0.16//slf4j-api-2.0.16.jar +slf4j-api/2.0.17//slf4j-api-2.0.17.jar snakeyaml-engine/2.9//snakeyaml-engine-2.9.jar snakeyaml/2.3//snakeyaml-2.3.jar 
snappy-java/1.1.10.7//snappy-java-1.1.10.7.jar diff --git a/docs/img/pyspark-components.pptx b/docs/img/pyspark-components.pptx index e0111a44e186e..754b3f1243cc2 100644 Binary files a/docs/img/pyspark-components.pptx and b/docs/img/pyspark-components.pptx differ diff --git a/docs/img/pyspark-machine_learning.png b/docs/img/pyspark-machine_learning.png index 7f4e6286f2046..15dbda2eb2071 100644 Binary files a/docs/img/pyspark-machine_learning.png and b/docs/img/pyspark-machine_learning.png differ diff --git a/docs/img/pyspark-pandas_api_on_spark.png b/docs/img/pyspark-pandas_api_on_spark.png index b4b291b3440f6..610d9ba40f8b0 100644 Binary files a/docs/img/pyspark-pandas_api_on_spark.png and b/docs/img/pyspark-pandas_api_on_spark.png differ diff --git a/docs/img/pyspark-python_spark_connect_client.png b/docs/img/pyspark-python_spark_connect_client.png new file mode 100644 index 0000000000000..2d7e35c0740b1 Binary files /dev/null and b/docs/img/pyspark-python_spark_connect_client.png differ diff --git a/docs/img/pyspark-spark_core_and_rdds.png b/docs/img/pyspark-spark_core_and_rdds.png index 8d06a446c1ad1..a53809438d052 100644 Binary files a/docs/img/pyspark-spark_core_and_rdds.png and b/docs/img/pyspark-spark_core_and_rdds.png differ diff --git a/docs/img/pyspark-spark_sql_and_dataframes.png b/docs/img/pyspark-spark_sql_and_dataframes.png index acd8b280de164..1ff6a5e22e054 100644 Binary files a/docs/img/pyspark-spark_sql_and_dataframes.png and b/docs/img/pyspark-spark_sql_and_dataframes.png differ diff --git a/docs/img/pyspark-structured_streaming.png b/docs/img/pyspark-structured_streaming.png index b49bb5b275509..60ba6457f0434 100644 Binary files a/docs/img/pyspark-structured_streaming.png and b/docs/img/pyspark-structured_streaming.png differ diff --git a/docs/sql-ref-datatypes.md b/docs/sql-ref-datatypes.md index 3a4530dcecaef..1745f348a369e 100644 --- a/docs/sql-ref-datatypes.md +++ b/docs/sql-ref-datatypes.md @@ -46,6 +46,8 @@ Spark SQL and DataFrames support the following data types: * Datetime type - `DateType`: Represents values comprising values of fields year, month and day, without a time-zone. + - `TimeType(precision)`: Represents values comprising values of fields hour, minute and second with the number of decimal digits `precision` following the decimal point in the seconds field, without a time-zone. + The range of values is from `00:00:00` to `23:59:59` for min precision `0`, and to `23:59:59.999999` for max precision `6`. - `TimestampType`: Timestamp with local time zone(TIMESTAMP_LTZ). It represents values comprising values of fields year, month, day, hour, minute, and second, with the session local time-zone. The timestamp value represents an absolute point in time. diff --git a/docs/sql-ref-syntax-aux-describe-table.md b/docs/sql-ref-syntax-aux-describe-table.md index 5f5fd27c865e6..100424a865b87 100644 --- a/docs/sql-ref-syntax-aux-describe-table.md +++ b/docs/sql-ref-syntax-aux-describe-table.md @@ -97,6 +97,8 @@ to return the metadata pertaining to a partition or column respectively. "partition_values": { "": "" }, + "partition_columns": ["col1", "col2"], + "clustering_columns": ["col1", "col2"], "location": "", "view_text": "", "view_original_text": "", diff --git a/docs/web-ui.md b/docs/web-ui.md index 9173ddef81d36..3889b41f03a04 100644 --- a/docs/web-ui.md +++ b/docs/web-ui.md @@ -37,6 +37,7 @@ DAG visualization, and all stages of the job. 
The information that is displayed in this section is * User: Current Spark user +* Started At: The startup time of Spark application * Total uptime: Time since Spark application started * Scheduling mode: See [job scheduling](job-scheduling.html#configuring-pool-properties) * Number of jobs per status: Active, Completed, Failed diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala index 482bb7fdc2105..2a3470e38f6ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala @@ -69,20 +69,18 @@ private[ml] object ANOVATest { val spark = dataset.sparkSession import spark.implicits._ - val resultDF = testClassification(dataset, featuresCol, labelCol) - .toDF("featureIndex", "pValue", "degreesOfFreedom", "fValue") + val resRdd = testClassification(dataset, featuresCol, labelCol) if (flatten) { - resultDF + resRdd.toDF("featureIndex", "pValue", "degreesOfFreedom", "fValue") } else { - resultDF.agg(collect_list(struct("*"))) - .as[Seq[(Int, Double, Long, Double)]] - .map { seq => - val results = seq.toArray.sortBy(_._1) - val pValues = Vectors.dense(results.map(_._2)) - val degreesOfFreedom = results.map(_._3) - val fValues = Vectors.dense(results.map(_._4)) - (pValues, degreesOfFreedom, fValues) + resRdd.coalesce(1) + .mapPartitions { iter => + val res = iter.toArray.sortBy(_._1) + val pValues = Vectors.dense(res.map(_._2)) + val degreesOfFreedom = res.map(_._3) + val fValues = Vectors.dense(res.map(_._4)) + Iterator.single((pValues, degreesOfFreedom, fValues)) }.toDF("pValues", "degreesOfFreedom", "fValues") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala index 2207c20049a1e..cdbfb6090acf5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala @@ -80,22 +80,21 @@ object ChiSquareTest { val data = dataset.select(col(labelCol).cast("double"), col(featuresCol)).rdd .map { case Row(label: Double, vec: Vector) => (label, OldVectors.fromML(vec)) } - val resultDF = OldChiSqTest.computeChiSquared(data) + val resRDD = OldChiSqTest.computeChiSquared(data) .map { case (col, pValue, degreesOfFreedom, statistic, _) => (col, pValue, degreesOfFreedom, statistic) - }.toDF("featureIndex", "pValue", "degreesOfFreedom", "statistic") + } if (flatten) { - resultDF + resRDD.toDF("featureIndex", "pValue", "degreesOfFreedom", "statistic") } else { - resultDF.agg(collect_list(struct("*"))) - .as[Seq[(Int, Double, Int, Double)]] - .map { seq => - val results = seq.toArray.sortBy(_._1) - val pValues = Vectors.dense(results.map(_._2)) - val degreesOfFreedom = results.map(_._3) - val statistics = Vectors.dense(results.map(_._4)) - (pValues, degreesOfFreedom, statistics) + resRDD.coalesce(1) + .mapPartitions { iter => + val res = iter.toArray.sortBy(_._1) + val pValues = Vectors.dense(res.map(_._2)) + val degreesOfFreedom = res.map(_._3) + val statistics = Vectors.dense(res.map(_._4)) + Iterator.single((pValues, degreesOfFreedom, statistics)) }.toDF("pValues", "degreesOfFreedom", "statistics") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala index e2ce6cf7214f7..56b7c058a5379 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala @@ -70,20 +70,18 @@ private[ml] object FValueTest { val spark = dataset.sparkSession import spark.implicits._ - val resultDF = testRegression(dataset, featuresCol, labelCol) - .toDF("featureIndex", "pValue", "degreesOfFreedom", "fValue") + val resRDD = testRegression(dataset, featuresCol, labelCol) if (flatten) { - resultDF + resRDD.toDF("featureIndex", "pValue", "degreesOfFreedom", "fValue") } else { - resultDF.agg(collect_list(struct("*"))) - .as[Seq[(Int, Double, Long, Double)]] - .map { seq => - val results = seq.toArray.sortBy(_._1) - val pValues = Vectors.dense(results.map(_._2)) - val degreesOfFreedom = results.map(_._3) - val fValues = Vectors.dense(results.map(_._4)) - (pValues, degreesOfFreedom, fValues) + resRDD.coalesce(1) + .mapPartitions { iter => + val res = iter.toArray.sortBy(_._1) + val pValues = Vectors.dense(res.map(_._2)) + val degreesOfFreedom = res.map(_._3) + val fValues = Vectors.dense(res.map(_._4)) + Iterator.single((pValues, degreesOfFreedom, fValues)) }.toDF("pValues", "degreesOfFreedom", "fValues") } } diff --git a/pom.xml b/pom.xml index 9947e186d2ee1..94e170149a7fd 100644 --- a/pom.xml +++ b/pom.xml @@ -120,7 +120,7 @@ 3.5.0 spark 9.7.1 - 2.0.16 + 2.0.17 2.24.3 3.4.1 @@ -197,6 +197,7 @@ 4.1.17 33.4.0-jre 2.11.0 + 4.0.7 3.1.9 3.0.16 2.13.0 @@ -214,7 +215,7 @@ 1.80 1.16.0 6.1.1 - 4.1.118.Final + 4.1.119.Final 2.0.70.Final 76.1 5.11.4 @@ -724,7 +725,7 @@ org.rocksdb rocksdbjni - 9.10.0 + 9.8.4 ${leveldbjni.group} @@ -860,6 +861,56 @@ ${protobuf.version} provided + + io.grpc + grpc-api + ${io.grpc.version} + + + io.grpc + grpc-protobuf + ${io.grpc.version} + + + io.grpc + grpc-stub + ${io.grpc.version} + + + io.grpc + grpc-netty + ${io.grpc.version} + + + io.grpc + grpc-services + ${io.grpc.version} + + + io.grpc + grpc-inprocess + ${io.grpc.version} + + + com.google.api.grpc + proto-google-common-protos + 2.41.0 + + + io.netty + netty-codec-http2 + ${netty.version} + + + io.netty + netty-handler-proxy + ${netty.version} + + + io.netty + netty-transport-native-unix-common + ${netty.version} + org.roaringbitmap RoaringBitmap @@ -1117,7 +1168,7 @@ org.json4s json4s-jackson_${scala.binary.version} - 4.0.7 + ${json4s.version} com.fasterxml.jackson.core + + org.json4s + json4s-core_${scala.binary.version} + ${json4s.version} + + + org.json4s + json4s-ast_${scala.binary.version} + ${json4s.version} + org.scala-lang.modules scala-xml_${scala.binary.version} @@ -2495,6 +2556,11 @@ + + org.apache.arrow + arrow-format + ${arrow.version} + org.apache.arrow arrow-vector @@ -2514,6 +2580,11 @@ + + org.apache.arrow + arrow-memory-core + ${arrow.version} + org.apache.arrow arrow-memory-netty diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 85c5474205d37..7ea894e5efcaa 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1003,7 +1003,7 @@ object KubernetesIntegrationTests { rDockerFile = "" } val extraOptions = if (javaImageTag.isDefined) { - Seq("-b", s"java_image_tag=$javaImageTag") + Seq("-b", s"java_image_tag=${javaImageTag.get}") } else { Seq("-f", s"$dockerFile") } diff --git a/python/docs/source/index.rst b/python/docs/source/index.rst index 72a846290fe9e..2e102c8de71e7 100644 --- a/python/docs/source/index.rst +++ b/python/docs/source/index.rst @@ -36,6 +36,18 @@ to enable processing and analysis of data at any size for everyone familiar with PySpark supports all of Spark's features such as Spark SQL, DataFrames, Structured Streaming, Machine
Learning (MLlib) and Spark Core. +.. list-table:: + :widths: 10 80 10 + :header-rows: 0 + :class: borderless spec_table + + * - + - .. image:: ../../../docs/img/pyspark-python_spark_connect_client.png + :target: getting_started/quickstart_connect.html + :width: 100% + :alt: Python Spark Connect Client + - + .. list-table:: :widths: 10 20 20 20 20 10 :header-rows: 0 @@ -72,6 +84,19 @@ DataFrames, Structured Streaming, Machine Learning (MLlib) and Spark Core. :alt: Spark Core and RDDs - +.. _Index Page - Python Spark Connect Client: + +**Python Spark Connect Client** + +Spark Connect is a client-server architecture within Apache Spark that +enables remote connectivity to Spark clusters from any application. +PySpark provides the client for the Spark Connect server, allowing +Spark to be used as a service. + +- :ref:`/getting_started/quickstart_connect.ipynb` +- |binder_connect|_ +- `Spark Connect Overview `_ + .. _Index Page - Spark SQL and DataFrames: **Spark SQL and DataFrames** diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst index 976aef7cb68b6..906e40140ceff 100644 --- a/python/docs/source/migration_guide/pyspark_upgrade.rst +++ b/python/docs/source/migration_guide/pyspark_upgrade.rst @@ -75,8 +75,6 @@ Upgrading from PySpark 3.5 to 4.0 * In Spark 4.0, ``compute.ops_on_diff_frames`` is on by default. To restore the previous behavior, set ``compute.ops_on_diff_frames`` to ``false``. * In Spark 4.0, the data type ``YearMonthIntervalType`` in ``DataFrame.collect`` no longer returns the underlying integers. To restore the previous behavior, set ``PYSPARK_YM_INTERVAL_LEGACY`` environment variable to ``1``. * In Spark 4.0, items other than functions (e.g. ``DataFrame``, ``Column``, ``StructType``) have been removed from the wildcard import ``from pyspark.sql.functions import *``, you should import these items from proper modules (e.g. ``from pyspark.sql import DataFrame, Column``, ``from pyspark.sql.types import StructType``). -* In Spark 4.0, ``spark.sql.execution.pythonUDF.arrow.enabled`` is enabled by default. If users have PyArrow and pandas installed in their local and Spark Cluster, it automatically optimizes the regular Python UDFs with Arrow. To turn off the Arrow optimization, set ``spark.sql.execution.pythonUDF.arrow.enabled`` to ``false``. -* In Spark 4.0, ``spark.sql.execution.arrow.pyspark.enabled`` is enabled by default. If users have PyArrow and pandas installed in their local and Spark Cluster, it automatically makes use of Apache Arrow for columnar data transfers in PySpark. This optimization applies to ``pyspark.sql.DataFrame.toPandas`` and ``pyspark.sql.SparkSession.createDataFrame`` when its input is a Pandas DataFrame or a NumPy ndarray. To turn off the Arrow optimization, set ``spark.sql.execution.arrow.pyspark.enabled`` to ``false``. Upgrading from PySpark 3.3 to 3.4 diff --git a/python/docs/source/user_guide/pandas_on_spark/options.rst b/python/docs/source/user_guide/pandas_on_spark/options.rst index e8fffea7e33be..31f3cff266de2 100644 --- a/python/docs/source/user_guide/pandas_on_spark/options.rst +++ b/python/docs/source/user_guide/pandas_on_spark/options.rst @@ -208,6 +208,16 @@ This is conceptually equivalent to the PySpark example as below: >>> spark_df.rdd.zipWithIndex().map(lambda p: p[1]).collect() [0, 1, 2] +.. 
warning:: + Unlike `sequence`, since `distributed-sequence` is executed in a distributed environment, + the rows corresponding to each index may vary although the index itself still + remains globally sequential. + This happens because the rows are distributed across multiple partitions and nodes, + leading to indeterministic row-to-index mappings when the data is loaded. + Therefore, it is recommended to explicitly set an index column by using the `index_col` parameter + instead of relying on the default index when creating a `DataFrame` + if the row-to-index mapping is critical for your application. + **distributed**: It implements a monotonically increasing sequence simply by using PySpark's `monotonically_increasing_id` function in a fully distributed manner. The values are indeterministic. If the index does not have to be a sequence that increases diff --git a/python/docs/source/user_guide/sql/arrow_pandas.rst b/python/docs/source/user_guide/sql/arrow_pandas.rst index b9e389f8fe7dd..fde40140110f9 100644 --- a/python/docs/source/user_guide/sql/arrow_pandas.rst +++ b/python/docs/source/user_guide/sql/arrow_pandas.rst @@ -356,8 +356,8 @@ Arrow Python UDFs are user defined functions that are executed row-by-row, utili transfer and serialization. To define an Arrow Python UDF, you can use the :meth:`udf` decorator or wrap the function with the :meth:`udf` method, ensuring the ``useArrow`` parameter is set to True. Additionally, you can enable Arrow optimization for Python UDFs throughout the entire SparkSession by setting the Spark configuration -``spark.sql.execution.pythonUDF.arrow.enabled`` to true, which is the default. It's important to note that the Spark -configuration takes effect only when ``useArrow`` is either not set or set to None. +``spark.sql.execution.pythonUDF.arrow.enabled`` to true. It's important to note that the Spark configuration takes +effect only when ``useArrow`` is either not set or set to None. The type hints for Arrow Python UDFs should be specified in the same way as for default, pickled Python UDFs. diff --git a/python/docs/source/user_guide/sql/type_conversions.rst b/python/docs/source/user_guide/sql/type_conversions.rst index 80f8aa83db7eb..2f13701995ef2 100644 --- a/python/docs/source/user_guide/sql/type_conversions.rst +++ b/python/docs/source/user_guide/sql/type_conversions.rst @@ -57,7 +57,7 @@ are listed below: - Default * - spark.sql.execution.pythonUDF.arrow.enabled - Enable PyArrow in PySpark. See more `here `_. - - True + - False * - spark.sql.pyspark.inferNestedDictAsStruct.enabled - When enabled, nested dictionaries are inferred as StructType. Otherwise, they are inferred as MapType.
- False diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py index 8cc2eb182f8f3..da4d25cc908c0 100755 --- a/python/packaging/classic/setup.py +++ b/python/packaging/classic/setup.py @@ -291,6 +291,7 @@ def run(self): "pyspark.sql.plot", "pyspark.sql.protobuf", "pyspark.sql.streaming", + "pyspark.sql.streaming.proto", "pyspark.sql.worker", "pyspark.streaming", "pyspark.bin", diff --git a/python/packaging/client/setup.py b/python/packaging/client/setup.py index fd4beba29b76d..8631bf9f8e593 100755 --- a/python/packaging/client/setup.py +++ b/python/packaging/client/setup.py @@ -168,6 +168,7 @@ "pyspark.sql.plot", "pyspark.sql.protobuf", "pyspark.sql.streaming", + "pyspark.sql.streaming.proto", "pyspark.sql.worker", "pyspark.streaming", "pyspark.pandas", diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py index 0bdfa27fc7021..224ef34fd5edc 100644 --- a/python/pyspark/ml/base.py +++ b/python/pyspark/ml/base.py @@ -328,8 +328,7 @@ def transformSchema(self, schema: StructType) -> StructType: def _transform(self, dataset: DataFrame) -> DataFrame: self.transformSchema(dataset.schema) - # TODO(SPARK-48515): Use Arrow Python UDF - transformUDF = udf(self.createTransformFunc(), self.outputDataType(), useArrow=False) + transformUDF = udf(self.createTransformFunc(), self.outputDataType()) transformedDataset = dataset.withColumn( self.getOutputCol(), transformUDF(dataset[self.getInputCol()]) ) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index a9892dc8db36d..3f1f35c2e3f81 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -908,12 +908,8 @@ def checker(foldNum: int) -> bool: from pyspark.sql.connect.udf import UserDefinedFunction else: from pyspark.sql.functions import UserDefinedFunction # type: ignore[assignment] - from pyspark.util import PythonEvalType - # TODO(SPARK-48515): Use Arrow Python UDF - checker_udf = UserDefinedFunction( - checker, BooleanType(), evalType=PythonEvalType.SQL_BATCHED_UDF - ) + checker_udf = UserDefinedFunction(checker, BooleanType()) for i in range(nFolds): training = dataset.filter(checker_udf(dataset[foldCol]) & (col(foldCol) != lit(i))) validation = dataset.filter( diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py index 81358185a5608..01e23214d662d 100644 --- a/python/pyspark/pandas/base.py +++ b/python/pyspark/pandas/base.py @@ -1191,7 +1191,6 @@ def _shift( return self._with_new_scol(col, field=self._internal.data_fields[0].copy(nullable=True)) # TODO: Update Documentation for Bins Parameter when its supported - # TODO(SPARK-51287): Enable s.index.value_counts() tests def value_counts( self, normalize: bool = False, @@ -1324,7 +1323,7 @@ def value_counts( ('falcon', 'length')], ) - >>> s.index.value_counts().sort_index() # doctest: +SKIP + >>> s.index.value_counts().sort_index() (cow, length) 1 (cow, weight) 2 (falcon, length) 2 @@ -1332,7 +1331,7 @@ def value_counts( (lama, weight) 3 Name: count, dtype: int64 - >>> s.index.value_counts(normalize=True).sort_index() # doctest: +SKIP + >>> s.index.value_counts(normalize=True).sort_index() (cow, length) 0.111111 (cow, weight) 0.222222 (falcon, length) 0.222222 diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 360f391de6c1c..062778b6ddb2c 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -670,7 +670,7 @@ def __init__( self._channel = self._builder.toChannel() self._closed = False - self._stub = 
grpc_lib.SparkConnectServiceStub(self._channel) + self._internal_stub = grpc_lib.SparkConnectServiceStub(self._channel) self._artifact_manager = ArtifactManager( self._user_id, self._session_id, self._channel, self._builder.metadata() ) @@ -688,6 +688,19 @@ def __init__( # cleanup ml cache if possible atexit.register(self._cleanup_ml) + @property + def _stub(self) -> grpc_lib.SparkConnectServiceStub: + if self.is_closed: + raise SparkConnectException( + errorClass="NO_ACTIVE_SESSION", messageParameters=dict() + ) from None + return self._internal_stub + + # For testing only. + @_stub.setter + def _stub(self, value: grpc_lib.SparkConnectServiceStub) -> None: + self._internal_stub = value + def register_progress_handler(self, handler: ProgressHandler) -> None: """ Register a progress handler to be called when a progress message is received. @@ -1796,11 +1809,6 @@ def _handle_error(self, error: Exception) -> NoReturn: self.thread_local.inside_error_handling = True if isinstance(error, grpc.RpcError): self._handle_rpc_error(error) - elif isinstance(error, ValueError): - if "Cannot invoke RPC" in str(error) and "closed" in str(error): - raise SparkConnectException( - errorClass="NO_ACTIVE_SESSION", messageParameters=dict() - ) from None raise error finally: self.thread_local.inside_error_handling = False diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 6045e441222de..30bb6ed9b29d0 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -23,8 +23,6 @@ import sys import functools -import warnings -from inspect import getfullargspec from typing import cast, Callable, Any, List, TYPE_CHECKING, Optional, Union from pyspark.util import PythonEvalType @@ -41,7 +39,6 @@ UDFRegistration as PySparkUDFRegistration, UserDefinedFunction as PySparkUserDefinedFunction, ) -from pyspark.sql.utils import has_arrow from pyspark.errors import PySparkTypeError, PySparkRuntimeError if TYPE_CHECKING: @@ -59,7 +56,6 @@ def _create_py_udf( returnType: "DataTypeOrString", useArrow: Optional[bool] = None, ) -> "UserDefinedFunctionLike": - is_arrow_enabled = False if useArrow is None: is_arrow_enabled = False try: @@ -80,27 +76,8 @@ def _create_py_udf( eval_type: int = PythonEvalType.SQL_BATCHED_UDF - if is_arrow_enabled and not has_arrow: - is_arrow_enabled = False - warnings.warn( - "Arrow optimization failed to enable because PyArrow is not installed. " - "Falling back to a non-Arrow-optimized UDF.", - RuntimeWarning, - ) - if is_arrow_enabled: - try: - is_func_with_args = len(getfullargspec(f).args) > 0 - except TypeError: - is_func_with_args = False - if is_func_with_args: - eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF - else: - warnings.warn( - "Arrow optimization for Python UDFs cannot be enabled for functions" - " without arguments.", - UserWarning, - ) + eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF return _create_udf(f, returnType, eval_type) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 34cf38bafdc68..32e8422129dbc 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -26419,8 +26419,7 @@ def udf( Defaults to :class:`StringType`. useArrow : bool, optional whether to use Arrow to optimize the (de)serialization. When it is None, the - Spark config "spark.sql.execution.pythonUDF.arrow.enabled" takes effect, - which is "true" by default. + Spark config "spark.sql.execution.pythonUDF.arrow.enabled" takes effect. 
Examples -------- diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 74d9a2ce65608..23b43c2bf8200 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1223,8 +1223,17 @@ def dump_stream(self, iterator, stream): Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow RecordBatches, and write batches to stream. """ - result = [(b, t) for x in iterator for y, t in x for b in y] - super().dump_stream(result, stream) + + def flatten_iterator(): + # iterator: iter[list[(iter[pandas.DataFrame], pdf_type)]] + for packed in iterator: + iter_pdf_with_type = packed[0] + iter_pdf = iter_pdf_with_type[0] + pdf_type = iter_pdf_with_type[1] + for pdf in iter_pdf: + yield (pdf, pdf_type) + + super().dump_stream(flatten_iterator(), stream) class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer): diff --git a/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py b/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py index fb25c448cef0a..334031ec362f1 100644 --- a/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py @@ -48,6 +48,10 @@ def conf(cls): return cfg + @unittest.skip("Flaky in spark connect on CI. Skip for now. See SPARK-51368 for details.") + def test_schema_evolution_scenarios(self): + pass + if __name__ == "__main__": from pyspark.sql.tests.connect.pandas.test_parity_pandas_transform_with_state import * # noqa: F401,E501 diff --git a/python/pyspark/sql/tests/connect/test_connect_creation.py b/python/pyspark/sql/tests/connect/test_connect_creation.py index 163a1c17bfafa..5352913f6609d 100644 --- a/python/pyspark/sql/tests/connect/test_connect_creation.py +++ b/python/pyspark/sql/tests/connect/test_connect_creation.py @@ -219,11 +219,6 @@ def test_with_atom_type(self): self.assert_eq(sdf.toPandas(), cdf.toPandas()) def test_with_none_and_nan(self): - # TODO(SPARK-51286): Fix test_with_none_and_nan to to pass with Arrow enabled - with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): - self.check_with_none_and_nan() - - def check_with_none_and_nan(self): # SPARK-41855: make createDataFrame support None and NaN # SPARK-41814: test with eqNullSafe data1 = [Row(id=1, value=float("NaN")), Row(id=2, value=42.0), Row(id=3, value=None)] diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index d1e2558305291..38ec0e29da7df 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -61,7 +61,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): cls.spark = cls.connect # Stopping Spark Connect closes the session in JVM at the server. 
- super(SparkConnectFunctionTests, cls).setUpClass() + super(SparkConnectFunctionTests, cls).tearDownClass() del os.environ["PYSPARK_NO_NAMESPACE_SHARE"] def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20): diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 79fbf46f005d4..6e02346f78bd7 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -34,7 +34,6 @@ ) from pyspark.sql.utils import get_active_spark_context from pyspark.sql.pandas.types import to_arrow_type -from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version from pyspark.errors import PySparkTypeError, PySparkNotImplementedError, PySparkRuntimeError if TYPE_CHECKING: @@ -116,7 +115,7 @@ def _create_py_udf( # Note: The values of 'SQL Type' are DDL formatted strings, which can be used as `returnType`s. # Note: The values inside the table are generated by `repr`. X' means it throws an exception # during the conversion. - is_arrow_enabled = False + if useArrow is None: from pyspark.sql import SparkSession @@ -129,18 +128,6 @@ def _create_py_udf( else: is_arrow_enabled = useArrow - if is_arrow_enabled: - try: - require_minimum_pandas_version() - require_minimum_pyarrow_version() - except ImportError: - is_arrow_enabled = False - warnings.warn( - "Arrow optimization failed to enable because PyArrow/pandas is not installed. " - "Falling back to a non-Arrow-optimized UDF.", - RuntimeWarning, - ) - eval_type: int = PythonEvalType.SQL_BATCHED_UDF if is_arrow_enabled: diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 63beda40dc52d..b0782d04cba3d 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -63,15 +63,6 @@ from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex -has_arrow: bool = False -try: - import pyarrow # noqa: F401 - - has_arrow = True -except ImportError: - pass - - FuncT = TypeVar("FuncT", bound=Callable[..., Any]) diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index 4b0ec2cd852fa..3511a64f7fe5b 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -199,9 +199,9 @@ to the wrapper scripts and using the wrapper scripts will simply set these appro spark.kubernetes.test.javaImageTag - A specific OpenJDK base image tag to use, when set uses it instead of azul/zulu-openjdk. + A specific Azul Zulu OpenJDK base image tag to use, when set uses it instead of 21. - azul/zulu-openjdk + N/A spark.kubernetes.test.imageTagFile diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala index d998502ac1b25..439b01a2eb060 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala @@ -276,11 +276,14 @@ object AgnosticEncoders { * another encoder. This is fallback for scenarios where objects can't be represented using * standard encoders, an example of this is where we use a different (opaque) serialization * format (i.e. java serialization, kryo serialization, or protobuf). 
+ * @param nullable + * defaults to false indicating the codec guarantees decode / encode results are non-nullable */ case class TransformingEncoder[I, O]( clsTag: ClassTag[I], transformed: AgnosticEncoder[O], - codecProvider: () => Codec[_ >: I, O]) + codecProvider: () => Codec[_ >: I, O], + override val nullable: Boolean = false) extends AgnosticEncoder[I] { override def isPrimitive: Boolean = transformed.isPrimitive override def dataType: DataType = transformed.dataType diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala index e870a83ec4ae6..da454c1c4214e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala @@ -21,7 +21,13 @@ import java.util.Locale import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.streaming.TimeMode -/** TimeMode types used in transformWithState operator */ +/** + * TimeMode types used in transformWithState operator + * + * Note that we need to keep TimeMode.None() named as "NoTime" in case class here because a case + * class named "None" will introduce naming collision with scala native type None. See SPARK-51151 + * for more info. + */ case object NoTime extends TimeMode case object ProcessingTime extends TimeMode @@ -31,7 +37,7 @@ case object EventTime extends TimeMode object TimeModes { def apply(timeMode: String): TimeMode = { timeMode.toLowerCase(Locale.ROOT) match { - case "none" => + case "none" | "notime" => NoTime case "processingtime" => ProcessingTime diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala index c69c5bfb52616..664501d22cb36 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala @@ -264,4 +264,11 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { errorClass = "_LEGACY_ERROR_TEMP_1189", messageParameters = Map("operation" -> operation)) } + + def unsupportedTimePrecisionError(precision: Int): Throwable = { + new SparkException( + errorClass = "UNSUPPORTED_TIME_PRECISION", + messageParameters = Map("precision" -> precision.toString), + cause = null) + } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/TimeType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/TimeType.scala new file mode 100644 index 0000000000000..873c63472acf0 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/TimeType.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.annotation.Unstable +import org.apache.spark.sql.errors.DataTypeErrors + +/** + * The time type represents a time value with fields hour, minute, second, up to microseconds. The + * range of times supported is 00:00:00.000000 to 23:59:59.999999. + * + * @param precision + * The time fractional seconds precision which indicates the number of decimal digits maintained + * following the decimal point in the seconds value. The supported range is [0, 6]. + * + * @since 4.1.0 + */ +@Unstable +case class TimeType(precision: Int) extends DatetimeType { + + if (precision < 0 || precision > 6) { + throw DataTypeErrors.unsupportedTimePrecisionError(precision) + } + + /** + * The default size of a value of the TimeType is 8 bytes. + */ + override def defaultSize: Int = 8 + + override def typeName: String = s"time($precision)" + + private[spark] override def asNullable: TimeType = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index fc477d1bc5ef5..492ea741236ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.{expressions => exprs} import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec} +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder} import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder} import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast} @@ -270,6 +270,8 @@ object DeserializerBuildHelper { enc: AgnosticEncoder[_], path: Expression, walkedTypePath: WalkedTypePath): Expression = enc match { + case ae: AgnosticExpressionPathEncoder[_] => + ae.fromCatalyst(path) case _ if isNativeEncoder(enc) => path case _: BoxedLeafEncoder[_, _] => @@ -447,13 +449,13 @@ object DeserializerBuildHelper { val result = InitializeJavaBean(newInstance, setters.toMap) exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result) - case TransformingEncoder(tag, _, codec) if codec == JavaSerializationCodec => + case TransformingEncoder(tag, _, codec, _) if codec == JavaSerializationCodec => 
DecodeUsingSerializer(path, tag, kryo = false) - case TransformingEncoder(tag, _, codec) if codec == KryoSerializationCodec => + case TransformingEncoder(tag, _, codec, _) if codec == KryoSerializationCodec => DecodeUsingSerializer(path, tag, kryo = true) - case TransformingEncoder(tag, encoder, provider) => + case TransformingEncoder(tag, encoder, provider, _) => Invoke( Literal.create(provider(), ObjectType(classOf[Codec[_, _]])), "decode", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index 089d463ecacbb..a9e20c1eb3e70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -21,7 +21,7 @@ import scala.language.existentials import org.apache.spark.sql.catalyst.{expressions => exprs} import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec} +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder} import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor} import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData} @@ -306,6 +306,7 @@ object SerializerBuildHelper { * by encoder `enc`. 
*/ private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match { + case ae: AgnosticExpressionPathEncoder[_] => ae.toCatalyst(input) case _ if isNativeEncoder(enc) => input case BoxedBooleanEncoder => createSerializerForBoolean(input) case BoxedByteEncoder => createSerializerForByte(input) @@ -418,18 +419,21 @@ object SerializerBuildHelper { } createSerializerForObject(input, serializedFields) - case TransformingEncoder(_, _, codec) if codec == JavaSerializationCodec => + case TransformingEncoder(_, _, codec, _) if codec == JavaSerializationCodec => EncodeUsingSerializer(input, kryo = false) - case TransformingEncoder(_, _, codec) if codec == KryoSerializationCodec => + case TransformingEncoder(_, _, codec, _) if codec == KryoSerializationCodec => EncodeUsingSerializer(input, kryo = true) - case TransformingEncoder(_, encoder, codecProvider) => + case TransformingEncoder(_, encoder, codecProvider, _) => val encoded = Invoke( Literal(codecProvider(), ObjectType(classOf[Codec[_, _]])), "encode", externalDataTypeFor(encoder), - input :: Nil) + input :: Nil, + propagateNull = input.nullable, + returnNullable = input.nullable + ) createSerializer(encoder, encoded) } @@ -486,6 +490,7 @@ object SerializerBuildHelper { nullable: Boolean): Expression => Expression = { input => val expected = enc match { case OptionEncoder(_) => lenientExternalDataTypeFor(enc) + case TransformingEncoder(_, transformed, _, _) => lenientExternalDataTypeFor(transformed) case _ => enc.dataType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index 8398fb8d1e830..752a2a648ce99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -22,9 +22,12 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, Expression, NamedExpression, OuterReference, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern._ object DeduplicateRelations extends Rule[LogicalPlan] { + val PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION = + TreeNodeTag[Unit]("project_for_expression_id_deduplication") type ExprIdMap = mutable.HashMap[Class[_], mutable.HashSet[Long]] @@ -67,7 +70,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] { val projectList = child.output.map { attr => Alias(attr, attr.name)() } - Project(projectList, child) + val project = Project(projectList, child) + project.setTagValue(DeduplicateRelations.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION, ()) + project } } u.copy(children = newChildren) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala index 294ee93a3c7bc..eae7d5a74dbc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{ Unpivot } import 
org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import org.apache.spark.sql.types.DataType @@ -142,7 +143,9 @@ abstract class TypeCoercionBase extends TypeCoercionHelper { case s @ Except(left, right, isAll) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + val newChildren: Seq[LogicalPlan] = withOrigin(s.origin) { + buildNewChildrenWithWiderTypes(left :: right :: Nil) + } if (newChildren.isEmpty) { s -> Nil } else { @@ -154,7 +157,9 @@ abstract class TypeCoercionBase extends TypeCoercionHelper { case s @ Intersect(left, right, isAll) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + val newChildren: Seq[LogicalPlan] = withOrigin(s.origin) { + buildNewChildrenWithWiderTypes(left :: right :: Nil) + } if (newChildren.isEmpty) { s -> Nil } else { @@ -166,7 +171,9 @@ abstract class TypeCoercionBase extends TypeCoercionHelper { case s: Union if s.childrenResolved && !s.byName && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) + val newChildren: Seq[LogicalPlan] = withOrigin(s.origin) { + buildNewChildrenWithWiderTypes(s.children) + } if (newChildren.isEmpty) { s -> Nil } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala index 0e4eed3c20f15..cfd81114b0300 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{ TypeCoercionBase } import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, ExprId} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Project, Union} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{DataType, MetadataBuilder} @@ -52,8 +52,6 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver) * for partially resolved subtrees from DataFrame programs. * - Resolve each child in the context of a) New [[NameScope]] b) New [[ExpressionIdAssigner]] * mapping. Collect child outputs to coerce them later. - * - Perform projection-based expression ID deduplication if required. This is a hack to stay - * compatible with fixed-point [[Analyzer]]. * - Perform individual output deduplication to handle the distinct union case described in * [[performIndividualOutputExpressionIdDeduplication]] scaladoc. * - Validate that child outputs have same length or throw "NUM_COLUMNS_MISMATCH" otherwise. @@ -68,10 +66,10 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver) * - Return the resolved [[Union]] with new children. 
*/ override def resolve(unresolvedUnion: Union): Union = { - val (oldOutput, oldChildOutputs) = if (unresolvedUnion.resolved) { - (Some(unresolvedUnion.output), Some(unresolvedUnion.children.map(_.output))) + val oldOutput = if (unresolvedUnion.resolved) { + Some(unresolvedUnion.output) } else { - (None, None) + None } val (resolvedChildren, childOutputs) = unresolvedUnion.children.zipWithIndex.map { @@ -84,16 +82,10 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver) } }.unzip - val (projectBasedDeduplicatedChildren, projectBasedDeduplicatedChildOutputs) = - performProjectionBasedExpressionIdDeduplication( - resolvedChildren, - childOutputs, - oldChildOutputs - ) val (deduplicatedChildren, deduplicatedChildOutputs) = performIndividualOutputExpressionIdDeduplication( - projectBasedDeduplicatedChildren, - projectBasedDeduplicatedChildOutputs + resolvedChildren, + childOutputs ) val (newChildren, newChildOutputs) = if (needToCoerceChildOutputs(deduplicatedChildOutputs)) { @@ -117,64 +109,6 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver) unresolvedUnion.copy(children = newChildren) } - /** - * Fixed-point [[Analyzer]] uses [[DeduplicateRelations]] rule to handle duplicate expression IDs - * in multi-child operator outputs. For [[Union]]s it uses a "projection-based deduplication", - * i.e. places another [[Project]] operator with new [[Alias]]es on the right child if duplicate - * expression IDs detected. New [[Alias]] "covers" the original attribute with new expression ID. - * This is done for all child operators except [[LeafNode]]s. - * - * We don't need this operation in single-pass [[Resolver]], since we have - * [[ExpressionIdAssigner]] for expression ID deduplication, but perform it nevertheless to stay - * compatible with fixed-point [[Analyzer]]. Since new outputs are already deduplicated by - * [[ExpressionIdAssigner]], we check the _old_ outputs for duplicates and place a [[Project]] - * only if old outputs are available (i.e. we are dealing with a resolved subtree from - * DataFrame program). - */ - private def performProjectionBasedExpressionIdDeduplication( - children: Seq[LogicalPlan], - childOutputs: Seq[Seq[Attribute]], - oldChildOutputs: Option[Seq[Seq[Attribute]]] - ): (Seq[LogicalPlan], Seq[Seq[Attribute]]) = { - oldChildOutputs match { - case Some(oldChildOutputs) => - val oldExpressionIds = new HashSet[ExprId] - - children - .zip(childOutputs) - .zip(oldChildOutputs) - .map { - case ((child: LeafNode, output), _) => - (child, output) - - case ((child, output), oldOutput) => - val oldOutputExpressionIds = new HashSet[ExprId] - - val hasConflicting = oldOutput.exists { oldAttribute => - oldOutputExpressionIds.add(oldAttribute.exprId) - oldExpressionIds.contains(oldAttribute.exprId) - } - - if (hasConflicting) { - val newExpressions = output.map { attribute => - Alias(attribute, attribute.name)() - } - ( - Project(projectList = newExpressions, child = child), - newExpressions.map(_.toAttribute) - ) - } else { - oldExpressionIds.addAll(oldOutputExpressionIds) - - (child, output) - } - } - .unzip - case _ => - (children, childOutputs) - } - } - /** * Deduplicate expression IDs at the scope of each individual child output. 
This is necessary to * handle the following case: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 1cb3520d4e265..5c4e9d4bddc5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -156,7 +156,7 @@ case class CatalogStorageFormat( def toJsonLinkedHashMap: mutable.LinkedHashMap[String, JValue] = { val map = mutable.LinkedHashMap[String, JValue]() - locationUri.foreach(l => map += ("Location" -> JString(l.toString))) + locationUri.foreach(l => map += ("Location" -> JString(CatalogUtils.URIToString(l)))) serde.foreach(s => map += ("Serde Library" -> JString(s))) inputFormat.foreach(format => map += ("InputFormat" -> JString(format))) outputFormat.foreach(format => map += ("OutputFormat" -> JString(format))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala index 81743251bada9..8f717795605f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.encoders import scala.collection.Map +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder, VariantEncoder} import org.apache.spark.sql.catalyst.expressions.Expression @@ -26,6 +27,30 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, VariantType, YearMonthIntervalType} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} +/** + * :: DeveloperApi :: + * Extensible [[AgnosticEncoder]] providing conversion extension points over type T + * @tparam T over T + */ +@DeveloperApi +@deprecated("This trait is intended only as a migration tool and will be removed in 4.1") +trait AgnosticExpressionPathEncoder[T] + extends AgnosticEncoder[T] { + /** + * Converts from T to InternalRow + * @param input the starting input path + * @return + */ + def toCatalyst(input: Expression): Expression + + /** + * Converts from InternalRow to T + * @param inputPath path expression from InternalRow + * @return + */ + def fromCatalyst(inputPath: Expression): Expression +} + /** * Helper class for Generating [[ExpressionEncoder]]s. 
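A hedged sketch of implementing the new extension point; the SerializerBuildHelper and DeserializerBuildHelper hunks above dispatch to toCatalyst/fromCatalyst before any other handling. The Name wrapper and NameEncoder below are illustrative only, not part of this change:

import scala.reflect.{classTag, ClassTag}

import org.apache.spark.sql.catalyst.encoders.AgnosticExpressionPathEncoder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance, StaticInvoke}
import org.apache.spark.sql.types.{DataType, ObjectType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Illustrative wrapper type.
case class Name(value: String)

object NameEncoder extends AgnosticExpressionPathEncoder[Name] {
  override def isPrimitive: Boolean = false
  override def dataType: DataType = StringType
  override def schema: StructType = new StructType().add("value", dataType)
  override def clsTag: ClassTag[Name] = classTag[Name]

  // Serialization path: Name -> UTF8String.
  override def toCatalyst(input: Expression): Expression =
    StaticInvoke(
      classOf[UTF8String],
      StringType,
      "fromString",
      Invoke(input, "value", ObjectType(classOf[String])) :: Nil)

  // Deserialization path: UTF8String -> Name.
  override def fromCatalyst(inputPath: Expression): Expression =
    NewInstance(
      classOf[Name],
      Invoke(inputPath, "toString", ObjectType(classOf[String])) :: Nil,
      ObjectType(classOf[Name]))
}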
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index b92acfb5b0f3a..084be5a350459 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.{Encoder, Row} import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, InternalRow, JavaTypeInference, ScalaReflection, SerializerBuildHelper} import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{OptionEncoder, TransformingEncoder} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull @@ -215,6 +216,13 @@ case class ExpressionEncoder[T]( StructField(s.name, s.dataType, s.nullable) }) + private def transformerOfOption(enc: AgnosticEncoder[_]): Boolean = + enc match { + case t: TransformingEncoder[_, _] => transformerOfOption(t.transformed) + case _: OptionEncoder[_] => true + case _ => false + } + /** * Returns true if the type `T` is serialized as a struct by `objSerializer`. */ @@ -228,7 +236,8 @@ case class ExpressionEncoder[T]( * returns true if `T` is serialized as struct and is not `Option` type. */ def isSerializedAsStructForTopLevel: Boolean = { - isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) + isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) && + !transformerOfOption(encoder) } // serializer expressions are used to encode an object to a row, while the object is usually an diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 67d9aff947cfa..4b182607f4da6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util.{DateTimeUtils, LegacyDateFormats, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeConstants._ @@ -2104,11 +2105,13 @@ case class ParseToDate( ansiEnabled: Boolean = SQLConf.get.ansiEnabled) extends RuntimeReplaceable with ImplicitCastInputTypes with TimeZoneAwareExpression { - override lazy val replacement: Expression = format.map { f => - Cast(GetTimestamp(left, f, TimestampType, "try_to_date", timeZoneId, ansiEnabled), DateType, - timeZoneId, EvalMode.fromBoolean(ansiEnabled)) - }.getOrElse(Cast(left, DateType, timeZoneId, - EvalMode.fromBoolean(ansiEnabled))) // backwards compatibility + override lazy val 
replacement: Expression = withOrigin(origin) { + format.map { f => + Cast(GetTimestamp(left, f, TimestampType, "try_to_date", timeZoneId, ansiEnabled), DateType, + timeZoneId, EvalMode.fromBoolean(ansiEnabled)) + }.getOrElse(Cast(left, DateType, timeZoneId, + EvalMode.fromBoolean(ansiEnabled))) // backwards compatibility + } def this(left: Expression, format: Expression) = { this(left, Option(format)) @@ -2183,9 +2186,11 @@ case class ParseToTimestamp( failOnError: Boolean = SQLConf.get.ansiEnabled) extends RuntimeReplaceable with ImplicitCastInputTypes with TimeZoneAwareExpression { - override lazy val replacement: Expression = format.map { f => - GetTimestamp(left, f, dataType, "try_to_timestamp", timeZoneId, failOnError = failOnError) - }.getOrElse(Cast(left, dataType, timeZoneId, ansiEnabled = failOnError)) + override lazy val replacement: Expression = withOrigin(origin) { + format.map { f => + GetTimestamp(left, f, dataType, "try_to_timestamp", timeZoneId, failOnError = failOnError) + }.getOrElse(Cast(left, dataType, timeZoneId, ansiEnabled = failOnError)) + } def this(left: Expression, format: Expression) = { this(left, Option(format), SQLConf.get.timestampType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 523b7c88fc8ce..2f40c11b15cf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -813,12 +813,15 @@ class AstBuilder extends DataTypeAstBuilder } /** - * Parameters used for writing query to a table: - * (table ident, options, tableColumnList, partitionKeys, ifPartitionNotExists, byName). + * Parameters used for writing query to a table. */ - type InsertTableParams = - (IdentifierReferenceContext, Option[OptionsClauseContext], Seq[String], - Map[String, Option[String]], Boolean, Boolean) + case class InsertTableParams( + relationCtx: IdentifierReferenceContext, + options: Option[OptionsClauseContext], + userSpecifiedCols: Seq[String], + partitionSpec: Map[String, Option[String]], + ifPartitionNotExists: Boolean, + byName: Boolean) /** * Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider). @@ -845,31 +848,36 @@ class AstBuilder extends DataTypeAstBuilder // 2. Write commands do not hold the table logical plan as a child, and we need to add // additional resolution code to resolve identifiers inside the write commands. 
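The withOrigin wrappers above matter because these replacements are lazy vals built later during analysis, after parsing has moved on. A small sketch of the origin propagation they rely on (the Literal is a stand-in for the real replacement tree):

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}

val callSite = Origin(line = Some(10), startPosition = Some(4))
val replacement = CurrentOrigin.withOrigin(callSite) {
  Literal(1)  // stand-in for the Cast/GetTimestamp tree built in the replacement
}
// The rebuilt expression keeps the original call position, so errors raised from it
// still point at the user's SQL text.
assert(replacement.origin == callSite)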
case table: InsertIntoTableContext => - val (relationCtx, options, cols, partition, ifPartitionNotExists, byName) - = visitInsertIntoTable(table) - withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => { + val insertParams = visitInsertIntoTable(table) + withIdentClause(insertParams.relationCtx, Seq(query), (ident, otherPlans) => { InsertIntoStatement( - createUnresolvedRelation(relationCtx, ident, options, Seq(TableWritePrivilege.INSERT)), - partition, - cols, - otherPlans.head, + table = createUnresolvedRelation( + ctx = insertParams.relationCtx, + ident = ident, + optionsClause = insertParams.options, + writePrivileges = Seq(TableWritePrivilege.INSERT)), + partitionSpec = insertParams.partitionSpec, + userSpecifiedCols = insertParams.userSpecifiedCols, + query = otherPlans.head, overwrite = false, - ifPartitionNotExists, - byName) + ifPartitionNotExists = insertParams.ifPartitionNotExists, + byName = insertParams.byName) }) case table: InsertOverwriteTableContext => - val (relationCtx, options, cols, partition, ifPartitionNotExists, byName) - = visitInsertOverwriteTable(table) - withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => { + val insertParams = visitInsertOverwriteTable(table) + withIdentClause(insertParams.relationCtx, Seq(query), (ident, otherPlans) => { InsertIntoStatement( - createUnresolvedRelation(relationCtx, ident, options, - Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)), - partition, - cols, - otherPlans.head, + table = createUnresolvedRelation( + ctx = insertParams.relationCtx, + ident = ident, + optionsClause = insertParams.options, + writePrivileges = Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)), + partitionSpec = insertParams.partitionSpec, + userSpecifiedCols = insertParams.userSpecifiedCols, + query = otherPlans.head, overwrite = true, - ifPartitionNotExists, - byName) + ifPartitionNotExists = insertParams.ifPartitionNotExists, + byName = insertParams.byName) }) case ctx: InsertIntoReplaceWhereContext => val options = Option(ctx.optionsClause()) @@ -896,8 +904,9 @@ class AstBuilder extends DataTypeAstBuilder */ override def visitInsertIntoTable( ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) { - val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil) - val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + val userSpecifiedCols = + Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil) + val partitionSpec = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) blockBang(ctx.errorCapturingNot()) @@ -905,8 +914,13 @@ class AstBuilder extends DataTypeAstBuilder invalidStatement("INSERT INTO ... 
IF NOT EXISTS", ctx) } - (ctx.identifierReference, Option(ctx.optionsClause()), cols, partitionKeys, false, - ctx.NAME() != null) + InsertTableParams( + relationCtx = ctx.identifierReference(), + options = Option(ctx.optionsClause()), + userSpecifiedCols = userSpecifiedCols, + partitionSpec = partitionSpec, + ifPartitionNotExists = false, + byName = ctx.NAME() != null) } /** @@ -915,19 +929,25 @@ class AstBuilder extends DataTypeAstBuilder override def visitInsertOverwriteTable( ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) { assert(ctx.OVERWRITE() != null) - val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil) - val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + val userSpecifiedCols = + Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil) + val partitionSpec = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) blockBang(ctx.errorCapturingNot()) - val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) + val dynamicPartitionKeys: Map[String, Option[String]] = partitionSpec.filter(_._2.isEmpty) if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { operationNotAllowed("IF NOT EXISTS with dynamic partitions: " + dynamicPartitionKeys.keys.mkString(", "), ctx) } - (ctx.identifierReference, Option(ctx.optionsClause()), cols, partitionKeys, - ctx.EXISTS() != null, ctx.NAME() != null) + InsertTableParams( + relationCtx = ctx.identifierReference, + options = Option(ctx.optionsClause()), + userSpecifiedCols = userSpecifiedCols, + partitionSpec = partitionSpec, + ifPartitionNotExists = ctx.EXISTS() != null, + byName = ctx.NAME() != null) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala index 62ef65eb11128..1651003dd7744 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans import java.util.HashMap -import org.apache.spark.sql.catalyst.analysis.GetViewColumnByNameAndOrdinal +import org.apache.spark.sql.catalyst.analysis.{DeduplicateRelations, GetViewColumnByNameAndOrdinal} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions @@ -145,6 +145,11 @@ object NormalizePlan extends PredicateHelper { .sortBy(_.hashCode()) .reduce(And) Join(left, right, newJoinType, Some(newCondition), hint) + case project: Project + if project + .getTagValue(DeduplicateRelations.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION) + .isDefined => + project.child case Project(projectList, child) => val projList = projectList .map { e => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index fb1999148d606..7d279092668f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -840,20 +840,30 @@ object View { // For temporary view, we always use captured sql configs if 
(activeConf.useCurrentSQLConfigsForView && !isTempView) return activeConf - val sqlConf = new SQLConf() // We retain below configs from current session because they are not captured by view // as optimization configs but they are still needed during the view resolution. - // TODO: remove this `retainedConfigs` after the `RelationConversions` is moved to + // TODO: remove this `retainedHiveConfigs` after the `RelationConversions` is moved to // optimization phase. + val retainedHiveConfigs = Seq( + "spark.sql.hive.convertMetastoreParquet", + "spark.sql.hive.convertMetastoreOrc", + "spark.sql.hive.convertInsertingPartitionedTable", + "spark.sql.hive.convertInsertingUnpartitionedTable", + "spark.sql.hive.convertMetastoreCtas" + ) + + val retainedLoggingConfigs = Seq( + "spark.sql.planChangeLog.level", + "spark.sql.expressionTreeChangeLog.level" + ) + val retainedConfigs = activeConf.getAllConfs.filter { case (key, _) => - Seq( - "spark.sql.hive.convertMetastoreParquet", - "spark.sql.hive.convertMetastoreOrc", - "spark.sql.hive.convertInsertingPartitionedTable", - "spark.sql.hive.convertInsertingUnpartitionedTable", - "spark.sql.hive.convertMetastoreCtas" - ).contains(key) || key.startsWith("spark.sql.catalog.") + retainedHiveConfigs.contains(key) || retainedLoggingConfigs.contains(key) || key.startsWith( + "spark.sql.catalog." + ) } + + val sqlConf = new SQLConf() for ((k, v) <- configs ++ retainedConfigs) { sqlConf.settings.put(k, v) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7e161fb9b7abe..8767d00767aa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3271,7 +3271,7 @@ object SQLConf { .doc("(Deprecated since Spark 3.0, please set 'spark.sql.execution.arrow.pyspark.enabled'.)") .version("2.3.0") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val ARROW_PYSPARK_EXECUTION_ENABLED = buildConf("spark.sql.execution.arrow.pyspark.enabled") @@ -3512,7 +3512,7 @@ object SQLConf { "can only be enabled when the given function takes at least one argument.") .version("3.4.0") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val PYTHON_UDF_ARROW_CONCURRENCY_LEVEL = buildConf("spark.sql.execution.pythonUDF.arrow.concurrency.level") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index bba784800976c..616c6d65636d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -23,20 +23,21 @@ import java.util.Arrays import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.reflect.classTag +import scala.reflect.{classTag, ClassTag} import scala.reflect.runtime.universe.TypeTag import org.apache.spark.{SPARK_DOC_ROOT, SparkArithmeticException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.{Encoder, Encoders, Row} -import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, OptionalData, PrimitiveData, ScroogeLikeExample} +import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, OptionalData, PrimitiveData, ScalaReflection, ScroogeLikeExample} 
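A compact, self-contained restatement of the retention predicate assembled in the View.effectiveSQLConf change above (the my_catalog key and the assertions are illustrative):

val retainedHiveConfigs = Seq(
  "spark.sql.hive.convertMetastoreParquet",
  "spark.sql.hive.convertMetastoreOrc",
  "spark.sql.hive.convertInsertingPartitionedTable",
  "spark.sql.hive.convertInsertingUnpartitionedTable",
  "spark.sql.hive.convertMetastoreCtas")

val retainedLoggingConfigs = Seq(
  "spark.sql.planChangeLog.level",
  "spark.sql.expressionTreeChangeLog.level")

def isRetained(key: String): Boolean =
  retainedHiveConfigs.contains(key) ||
    retainedLoggingConfigs.contains(key) ||
    key.startsWith("spark.sql.catalog.")

assert(isRetained("spark.sql.planChangeLog.level"))    // now survives view resolution
assert(isRetained("spark.sql.catalog.my_catalog.url"))
assert(!isRetained("spark.sql.shuffle.partitions"))    // comes from the view's captured configs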
import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, TransformingEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, EncoderField, IterableEncoder, MapEncoder, OptionEncoder, PrimitiveIntEncoder, ProductEncoder, TimestampEncoder, TransformingEncoder} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NaNvl} import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils.{instantToMicros, microsToInstant} import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -142,6 +143,21 @@ case class OptionNestedGeneric[T](list: Option[T]) case class MapNestedGenericKey[T](list: Map[T, Int]) case class MapNestedGenericValue[T](list: Map[Int, T]) +// ADT encoding for TransformingEncoder test +trait Base { + def name: String +} + +case class A(name: String, number: Int) extends Base + +case class B(name: String, text: String) extends Base + +case class Struct(typ: String, name: String, number: Option[Int] = None, + text: Option[String] = None) +// end ADT encoding + +case class V[A](v: A) + class Wrapper[T](val value: T) { override def hashCode(): Int = value.hashCode() override def equals(obj: Any): Boolean = obj match { @@ -585,6 +601,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes encodeDecodeTest(FooEnum.E1, "scala Enum") + // TransformingEncoder tests ---------------------------------------------------------- private def testTransformingEncoder( name: String, @@ -592,7 +609,8 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes val encoder = ExpressionEncoder(TransformingEncoder( classTag[(Long, Long)], BinaryEncoder, - provider)) + provider, + nullable = true)) .resolveAndBind() assert(encoder.schema == new StructType().add("value", BinaryType)) val toRow = encoder.createSerializer() @@ -603,6 +621,32 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes testTransformingEncoder("transforming java serialization encoder", JavaSerializationCodec) testTransformingEncoder("transforming kryo encoder", KryoSerializationCodec) + test("transforming encoders ADT - Frameless Injections use case") { + val provider = () => new Codec[Base, Struct]{ + override def encode(in: Base): Struct = in match { + case A(name, number) => Struct("A", name, number = Some(number)) + case B(name, text) => Struct("B", name, text = Some(text)) + } + + override def decode(out: Struct): Base = out match { + case Struct("A", name, Some(number), None) => A(name, number) + case Struct("B", name, None, Some(text)) => B(name, text) + case _ => throw new Exception(f"Invalid Base structure {s}") + } + } + val encoder = ExpressionEncoder(TransformingEncoder( + classTag[Base], + ScalaReflection.encoderFor[Struct], + provider)) + .resolveAndBind() + + val toRow = encoder.createSerializer() + val fromRow = encoder.createDeserializer() + + assert(fromRow(toRow(A("anA", 1))) == A("anA", 1)) + assert(fromRow(toRow(B("aB", "text"))) == B("aB", "text")) + } + test("transforming row encoder") { val schema = new StructType().add("a", 
LongType).add("b", StringType) val encoder = ExpressionEncoder(TransformingEncoder( @@ -615,7 +659,141 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes assert(fromRow(toRow(new Wrapper(Row(9L, "x")))) == new Wrapper(Row(9L, "x"))) } + // below tests are related to SPARK-49960 and TransformingEncoder usage + test("""Encoder with OptionEncoder of transformation""".stripMargin) { + type T = Option[V[V[Int]]] + val data: Seq[T] = Seq(None, Some(V(V(1)))) + + /* attempt to behave as if value class semantics except the last product, + using a final transforming instead of a product serializes */ + val enc = + OptionEncoder( + transforming( + V_OF_INT, + true + ) + ) + + testDataTransformingEnc(enc, data) + } + def testDataTransformingEnc[T](enc: AgnosticEncoder[T], data: Seq[T]): Unit = { + val encoder = ExpressionEncoder[T](enc).resolveAndBind() + val toRow = encoder.createSerializer() + val fromRow = encoder.createDeserializer() + data.foreach{ row => + assert(fromRow(toRow(data.head)) === data.head) + } + } + + def provider[A]: () => Codec[V[A], A] = () => + new Codec[V[A], A]{ + override def encode(in: V[A]): A = in.v + override def decode(out: A): V[A] = if (out == null) null else V(out) + } + + def transforming[A](underlying: AgnosticEncoder[A], + useUnderyling: Boolean = false): TransformingEncoder[V[A], A] = + TransformingEncoder[V[A], A]( + implicitly[ClassTag[V[A]]], + underlying, + provider, + if (useUnderyling) { + underlying.nullable + } else { + false + } + ) + + val V_INT = StructType(Seq(StructField("v", IntegerType, nullable = true))) + + // product encoder for a non-nullable V + val V_OF_INT = + ProductEncoder( + classTag[V[Int]], + Seq(EncoderField("v", PrimitiveIntEncoder, nullable = false, Metadata.empty)), + None + ) + + test("""Encoder derivation with nested TransformingEncoder of OptionEncoder""".stripMargin) { + type T = V[V[Option[V[Int]]]] + val data: Seq[T] = Seq(V(V(None)), V(V(Some(V(1))))) + + /* attempt to behave as if value class semantics except the last product, + using a final transforming instead of a product serializes */ + val enc = + transforming( + transforming( + OptionEncoder( + V_OF_INT + ) + ) + ) + + testDataTransformingEnc(enc, data) + } + + test("""Encoder derivation with TransformingEncoder of OptionEncoder""".stripMargin) { + type T = V[Option[V[Int]]] + val data: Seq[T] = Seq(V(None), V(Some(V(1)))) + + /* attempt to behave as if value class semantics except the last product, + using a final transforming instead of a product serializes */ + val enc = + transforming( + OptionEncoder( + V_OF_INT + ) + ) + + testDataTransformingEnc(enc, data) + } + + val longEncForTimestamp: AgnosticEncoder[V[Long]] = + TransformingEncoder[V[Long], java.sql.Timestamp]( + classTag, + TimestampEncoder(true), + () => + new Codec[V[Long], java.sql.Timestamp] with Serializable { + override def encode(in: V[Long]): Timestamp = Timestamp.from(microsToInstant(in.v)) + + override def decode(out: Timestamp): V[Long] = V[Long](instantToMicros(out.toInstant)) + } + ) + + test("""TransformingEncoder as Iterable""".stripMargin) { + type T = Seq[V[Long]] + val data: Seq[T] = Seq(Seq(V(0)), Seq(V(1), V(2))) + + /* requires validateAndSerializeElement to test for TransformingEncoder */ + val enc: AgnosticEncoder[T] = + IterableEncoder[Seq[V[Long]], V[Long]]( + implicitly[ClassTag[Seq[V[Long]]]], + longEncForTimestamp, + containsNull = false, + lenientSerialization = false) + + assert(enc.dataType === new ArrayType(TimestampType, false)) + + 
testDataTransformingEnc(enc, data) + } + + test("""TransformingEncoder as Map Key/Value""".stripMargin) { + type T = Map[V[Long], V[Long]] + val data: Seq[T] = Seq(Map(V(0L) -> V(0L)), Map(V(1L) -> V(1L)), Map(V(2L) -> V(2L))) + + /* requires validateAndSerializeElement to test for TransformingEncoder */ + val enc: AgnosticEncoder[T] = + MapEncoder[T, V[Long], V[Long]]( + implicitly[ClassTag[T]], + longEncForTimestamp, + longEncForTimestamp, + valueContainsNull = false) + + assert(enc.dataType === new MapType(TimestampType, TimestampType, false)) + + testDataTransformingEnc(enc, data) + } // Scala / Java big decimals ---------------------------------------------------------- encodeDecodeTest(BigDecimal(("9" * 20) + "." + "9" * 18), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 397241be76eb1..5f8689ec37216 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -380,6 +380,8 @@ class DataTypeSuite extends SparkFunSuite { checkDefaultSize(VarcharType(10), 10) yearMonthIntervalTypes.foreach(checkDefaultSize(_, 4)) dayTimeIntervalTypes.foreach(checkDefaultSize(_, 8)) + checkDefaultSize(TimeType(0), 8) + checkDefaultSize(TimeType(6), 8) def checkEqualsIgnoreCompatibleNullability( from: DataType, @@ -1371,4 +1373,18 @@ class DataTypeSuite extends SparkFunSuite { } assert(exception.getMessage.contains("The length of varchar type cannot be negative.")) } + + test("precisions of the TIME data type") { + 0 to 6 foreach { p => assert(TimeType(p).sql == s"TIME($p)") } + + Seq(Int.MinValue, -1, 7, Int.MaxValue).foreach { p => + checkError( + exception = intercept[SparkException] { + TimeType(p) + }, + condition = "UNSUPPORTED_TIME_PRECISION", + parameters = Map("precision" -> p.toString) + ) + } + } } diff --git a/sql/connect/client/jvm/pom.xml b/sql/connect/client/jvm/pom.xml index 81e195fb5ccde..11eb65add2016 100644 --- a/sql/connect/client/jvm/pom.xml +++ b/sql/connect/client/jvm/pom.xml @@ -75,10 +75,8 @@ compile - com.google.guava - guava - ${connect.guava.version} - compile + io.grpc + grpc-netty com.google.guava @@ -133,6 +131,16 @@ ${mima.version} test + + com.google.protobuf + protobuf-java-util + test + + + io.grpc + grpc-inprocess + test + target/scala-${scala.binary.version}/classes @@ -148,7 +156,6 @@ true - com.google.guava:* com.google.android:* com.google.api.grpc:* com.google.code.findbugs:* @@ -168,13 +175,6 @@ - - com.google.common - ${spark.shade.packageName}.connect.guava - - com.google.common.** - - io.grpc ${spark.shade.packageName}.io.grpc @@ -185,10 +185,6 @@ com.google ${spark.shade.packageName}.com.google - - - com.google.common.** - io.netty diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/TransformWithStateConnectSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/TransformWithStateConnectSuite.scala new file mode 100644 index 0000000000000..310b50dac1cc3 --- /dev/null +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/TransformWithStateConnectSuite.scala @@ -0,0 +1,522 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.streaming + +import java.io.{BufferedWriter, File, FileWriter} +import java.nio.file.Paths +import java.sql.Timestamp + +import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.Futures.timeout +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, Dataset, Encoders, Row} +import org.apache.spark.sql.connect.SparkSession +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.{ListState, MapState, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode, TimerValues, TTLConfig, ValueState} +import org.apache.spark.sql.types._ + +case class InputRowForConnectTest(key: String, value: String) +case class OutputRowForConnectTest(key: String, value: String) +case class StateRowForConnectTest(count: Long) + +// A basic stateful processor which will return the occurrences of key +class BasicCountStatefulProcessor + extends StatefulProcessor[String, InputRowForConnectTest, OutputRowForConnectTest] + with Logging { + @transient protected var _countState: ValueState[StateRowForConnectTest] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + _countState = getHandle.getValueState[StateRowForConnectTest]( + "countState", + Encoders.product[StateRowForConnectTest], + TTLConfig.NONE) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[InputRowForConnectTest], + timerValues: TimerValues): Iterator[OutputRowForConnectTest] = { + val count = inputRows.toSeq.length + { + if (_countState.exists()) { + _countState.get().count + } else { + 0L + } + } + _countState.update(StateRowForConnectTest(count)) + Iterator(OutputRowForConnectTest(key, count.toString)) + } +} + +// A stateful processor with initial state which will return the occurrences of key +class TestInitialStatefulProcessor + extends StatefulProcessorWithInitialState[ + String, + (String, String), + (String, String), + (String, String, String)] + with Logging { + @transient protected var _countState: ValueState[Long] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong, TTLConfig.NONE) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[(String, String)], + timerValues: TimerValues): Iterator[(String, String)] = { + val count = inputRows.toSeq.length + { + if (_countState.exists()) { + _countState.get() + } else { + 0L + } + } + _countState.update(count) + Iterator((key, count.toString)) + } + + override def handleInitialState( + key: String, + initialState: (String, String, String), + timerValues: TimerValues): Unit = { + val count = 1 + { + if (_countState.exists()) { + _countState.get() + } else { 
+ 0L + } + } + _countState.update(count) + } +} + +case class OutputEventTimeRow(key: String, outputTimestamp: Timestamp) + +// A stateful processor which will return timestamp of the first item from input rows +class ChainingOfOpsStatefulProcessor + extends StatefulProcessor[String, (String, Timestamp), OutputEventTimeRow] { + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {} + + override def handleInputRows( + key: String, + inputRows: Iterator[(String, Timestamp)], + timerValues: TimerValues): Iterator[OutputEventTimeRow] = { + val timestamp = inputRows.next()._2 + Iterator(OutputEventTimeRow(key, timestamp)) + } +} + +// A basic stateful processor contains composite state variables and TTL +class TTLTestStatefulProcessor + extends StatefulProcessor[String, (String, String), (String, String)] { + import java.time.Duration + + @transient protected var countState: ValueState[Int] = _ + @transient protected var ttlCountState: ValueState[Int] = _ + @transient protected var ttlListState: ListState[Int] = _ + @transient protected var ttlMapState: MapState[String, Int] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + countState = getHandle.getValueState[Int]("countState", Encoders.scalaInt, TTLConfig.NONE) + ttlCountState = getHandle + .getValueState[Int]("ttlCountState", Encoders.scalaInt, TTLConfig(Duration.ofMillis(1000))) + ttlListState = getHandle + .getListState[Int]("ttlListState", Encoders.scalaInt, TTLConfig(Duration.ofMillis(1000))) + ttlMapState = getHandle.getMapState[String, Int]( + "ttlMapState", + Encoders.STRING, + Encoders.scalaInt, + TTLConfig(Duration.ofMillis(1000))) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[(String, String)], + timerValues: TimerValues): Iterator[(String, String)] = { + val numOfInputRows = inputRows.toSeq.length + var count = numOfInputRows + var ttlCount = numOfInputRows + var ttlListStateCount = numOfInputRows + var ttlMapStateCount = numOfInputRows + + if (countState.exists()) { + count += countState.get() + } + if (ttlCountState.exists()) { + ttlCount += ttlCountState.get() + } + if (ttlListState.exists()) { + for (value <- ttlListState.get()) { + ttlListStateCount += value + } + } + if (ttlMapState.exists()) { + ttlMapStateCount = ttlMapState.getValue(key) + } + countState.update(count) + if (key != "0") { + ttlCountState.update(ttlCount) + ttlListState.put(Array(ttlListStateCount, ttlListStateCount)) + ttlMapState.updateValue(key, ttlMapStateCount) + } + val output = List( + (s"count-$key", count.toString), + (s"ttlCount-$key", ttlCount.toString), + (s"ttlListState-$key", ttlListStateCount.toString), + (s"ttlMapState-$key", ttlMapStateCount.toString)) + output.iterator + } +} + +class TransformWithStateConnectSuite + extends QueryTest + with RemoteSparkSession + with Logging + with BeforeAndAfterEach { + val testData: Seq[(String, String)] = Seq(("a", "1"), ("b", "1"), ("a", "2")) + val twsAdditionalSQLConf = Seq( + "spark.sql.streaming.stateStore.providerClass" -> + "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider", + "spark.sql.shuffle.partitions" -> "5", + "spark.sql.session.timeZone" -> "UTC", + "spark.sql.streaming.noDataMicroBatches.enabled" -> "false") + + override def beforeAll(): Unit = { + super.beforeAll() + spark.sql("DROP TABLE IF EXISTS my_sink") + } + + override protected def afterEach(): Unit = { + try { + spark.sql("DROP TABLE IF EXISTS my_sink") + } finally { + super.afterEach() + } + } + + test("transformWithState - 
streaming with state variable, case class type") { + withSQLConf(twsAdditionalSQLConf: _*) { + val session: SparkSession = spark + import session.implicits._ + + withTempPath { dir => + val path = dir.getCanonicalPath + testData + .toDS() + .toDF("key", "value") + .repartition(3) + .write + .parquet(path) + + val testSchema = + StructType(Array(StructField("key", StringType), StructField("value", StringType))) + + val q = spark.readStream + .schema(testSchema) + .option("maxFilesPerTrigger", 1) + .parquet(path) + .as[InputRowForConnectTest] + .groupByKey(x => x.key) + .transformWithState[OutputRowForConnectTest]( + new BasicCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + .writeStream + .format("memory") + .queryName("my_sink") + .start() + + try { + q.processAllAvailable() + eventually(timeout(30.seconds)) { + checkDatasetUnorderly( + spark.table("my_sink").toDF().as[(String, String)], + ("a", "1"), + ("a", "2"), + ("b", "1")) + } + } finally { + q.stop() + } + } + } + } + + test("transformWithState - streaming with initial state") { + withSQLConf(twsAdditionalSQLConf: _*) { + val session: SparkSession = spark + import session.implicits._ + + withTempPath { dir => + val path = dir.getCanonicalPath + testData + .toDS() + .toDF("key", "value") + .repartition(3) + .write + .parquet(path) + + val testSchema = + StructType(Array(StructField("key", StringType), StructField("value", StringType))) + + val initDf = Seq(("init_1", "40.0", "a"), ("init_2", "100.0", "b")) + .toDS() + .groupByKey(x => x._3) + .mapValues(x => x) + + val q = spark.readStream + .schema(testSchema) + .option("maxFilesPerTrigger", 1) + .parquet(path) + .as[(String, String)] + .groupByKey(x => x._1) + .transformWithState( + new TestInitialStatefulProcessor(), + TimeMode.None(), + OutputMode.Update(), + initialState = initDf) + .writeStream + .format("memory") + .queryName("my_sink") + .start() + + try { + q.processAllAvailable() + eventually(timeout(30.seconds)) { + checkDatasetUnorderly( + spark.table("my_sink").toDF().as[(String, String)], + ("a", "2"), + ("a", "3"), + ("b", "2")) + } + } finally { + q.stop() + } + } + } + } + + test("transformWithState - streaming with chaining of operators") { + withSQLConf(twsAdditionalSQLConf: _*) { + val session: SparkSession = spark + import session.implicits._ + + def timestamp(num: Int): Timestamp = { + new Timestamp(num * 1000) + } + + val checkResultFunc: (Dataset[Row], Long) => Unit = { (batchDF, batchId) => + val realDf = batchDF.collect().toSet + if (batchId == 0) { + assert(realDf.isEmpty, s"BatchId: $batchId, RealDF: $realDf") + } else if (batchId == 1) { + // eviction watermark = 15 - 5 = 10 (max event time from batch 0), + // late event watermark = 0 (eviction event time from batch 0) + val expectedDF = Seq(Row(timestamp(10), 1L)).toSet + assert( + realDf == expectedDF, + s"BatchId: $batchId, expectedDf: $expectedDF, RealDF: $realDf") + } else if (batchId == 2) { + // eviction watermark = 25 - 5 = 20, late event watermark = 10; + // row with watermark=5<10 is dropped so it does not show up in the results; + // row with eventTime<=20 are finalized and emitted + val expectedDF = Seq(Row(timestamp(11), 1L), Row(timestamp(15), 1L)).toSet + assert( + realDf == expectedDF, + s"BatchId: $batchId, expectedDf: $expectedDF, RealDF: $realDf") + } + } + + withTempPath { dir => + val path = dir.getCanonicalPath + val curTime = System.currentTimeMillis + val file1 = prepareInputData(path + "/text-test3.csv", Seq("a", "b"), Seq(10, 15)) + 
file1.setLastModified(curTime + 2L) + val file2 = prepareInputData(path + "/text-test4.csv", Seq("a", "c"), Seq(11, 25)) + file2.setLastModified(curTime + 4L) + val file3 = prepareInputData(path + "/text-test1.csv", Seq("a"), Seq(5)) + file3.setLastModified(curTime + 6L) + + val q = buildTestDf(path, spark) + .select(col("key").as("key"), timestamp_seconds(col("value")).as("eventTime")) + .withWatermark("eventTime", "5 seconds") + .as[(String, Timestamp)] + .groupByKey(x => x._1) + .transformWithState[OutputEventTimeRow]( + new ChainingOfOpsStatefulProcessor(), + "outputTimestamp", + OutputMode.Append()) + .groupBy("outputTimestamp") + .count() + .writeStream + .foreachBatch(checkResultFunc) + .outputMode("Append") + .start() + + q.processAllAvailable() + eventually(timeout(30.seconds)) { + q.stop() + } + } + } + } + + test("transformWithState - streaming with TTL and composite state variables") { + withSQLConf(twsAdditionalSQLConf: _*) { + val session: SparkSession = spark + import session.implicits._ + + val checkResultFunc = (batchDF: Dataset[(String, String)], batchId: Long) => { + if (batchId == 0) { + val expectedDF = Set( + ("count-0", "1"), + ("ttlCount-0", "1"), + ("ttlListState-0", "1"), + ("ttlMapState-0", "1"), + ("count-1", "1"), + ("ttlCount-1", "1"), + ("ttlListState-1", "1"), + ("ttlMapState-1", "1")) + + val realDf = batchDF.collect().toSet + assert(realDf == expectedDF) + + } else if (batchId == 1) { + val expectedDF = Set( + ("count-0", "2"), + ("ttlCount-0", "1"), + ("ttlListState-0", "1"), + ("ttlMapState-0", "1"), + ("count-1", "2"), + ("ttlCount-1", "1"), + ("ttlListState-1", "1"), + ("ttlMapState-1", "1")) + + val realDf = batchDF.collect().toSet + assert(realDf == expectedDF) + } + + if (batchId == 0) { + // let ttl state expires + Thread.sleep(2000) + } + } + + withTempPath { dir => + val path = dir.getCanonicalPath + val curTime = System.currentTimeMillis + val file1 = prepareInputData(path + "/text-test3.csv", Seq("1", "0"), Seq(0, 0)) + file1.setLastModified(curTime + 2L) + val file2 = prepareInputData(path + "/text-test4.csv", Seq("1", "0"), Seq(0, 0)) + file2.setLastModified(curTime + 4L) + + val q = buildTestDf(path, spark) + .as[(String, String)] + .groupByKey(x => x._1) + .transformWithState( + new TTLTestStatefulProcessor(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + .writeStream + .foreachBatch(checkResultFunc) + .outputMode("Update") + .start() + q.processAllAvailable() + + eventually(timeout(30.seconds)) { + q.stop() + } + } + } + } + + test("transformWithState - batch query") { + withSQLConf(twsAdditionalSQLConf: _*) { + val session: SparkSession = spark + import session.implicits._ + + withTempPath { dir => + val path = dir.getCanonicalPath + testData + .toDS() + .toDF("key", "value") + .repartition(3) + .write + .parquet(path) + + val testSchema = + StructType(Array(StructField("key", StringType), StructField("value", StringType))) + + spark.read + .schema(testSchema) + .parquet(path) + .as[InputRowForConnectTest] + .groupByKey(x => x.key) + .transformWithState[OutputRowForConnectTest]( + new BasicCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + .write + .saveAsTable("my_sink") + + checkDatasetUnorderly( + spark.table("my_sink").toDF().as[(String, String)], + ("a", "2"), + ("b", "1")) + } + } + } + + /* Utils functions for tests */ + def prepareInputData(inputPath: String, col1: Seq[String], col2: Seq[Int]): File = { + // Ensure the parent directory exists + val file = Paths.get(inputPath).toFile + val parentDir = 
file.getParentFile + if (parentDir != null && !parentDir.exists()) { + parentDir.mkdirs() + } + + val writer = new BufferedWriter(new FileWriter(inputPath)) + try { + col1.zip(col2).foreach { case (e1, e2) => + writer.write(s"$e1, $e2\n") + } + } finally { + writer.close() + } + file + } + + def buildTestDf(inputPath: String, sparkSession: SparkSession): DataFrame = { + sparkSession.readStream + .format("csv") + .schema( + new StructType() + .add(StructField("key", StringType)) + .add(StructField("value", StringType))) + .option("maxFilesPerTrigger", 1) + .load(inputPath) + .select(col("key").as("key"), col("value").cast("integer")) + } +} diff --git a/sql/connect/common/pom.xml b/sql/connect/common/pom.xml index e6745df9013ec..3fdd33ab52e19 100644 --- a/sql/connect/common/pom.xml +++ b/sql/connect/common/pom.xml @@ -39,53 +39,85 @@ spark-sql-api_${scala.binary.version} ${project.version} + + org.apache.spark + spark-unsafe_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-tags_${scala.binary.version} + + + org.apache.spark + spark-common-utils_${scala.binary.version} + ${project.version} + org.scala-lang scala-library + + org.scala-lang + scala-reflect + + + org.slf4j + slf4j-api + + + org.apache.commons + commons-lang3 + + + commons-codec + commons-codec + + + com.google.guava + guava + ${connect.guava.version} + compile + com.google.protobuf protobuf-java io.grpc - grpc-netty - ${io.grpc.version} + grpc-api io.grpc grpc-protobuf - ${io.grpc.version} io.grpc - grpc-services - ${io.grpc.version} + grpc-stub - io.grpc - grpc-stub - ${io.grpc.version} + com.google.api.grpc + proto-google-common-protos - io.grpc - grpc-inprocess - ${io.grpc.version} + org.apache.arrow + arrow-format + + + org.apache.arrow + arrow-vector - io.netty - netty-codec-http2 - ${netty.version} + org.apache.arrow + arrow-memory-core - io.netty - netty-handler-proxy - ${netty.version} + org.json4s + json4s-core_${scala.binary.version} - io.netty - netty-transport-native-unix-common - ${netty.version} + org.json4s + json4s-ast_${scala.binary.version} org.apache.tomcat @@ -114,6 +146,50 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + + org.apache.maven.plugins + maven-dependency-plugin + + + + org.apache.spark:spark-connect-shims_${scala.binary.version} + + + + + org.spark-project.spark:unused + + + org.apache.tomcat:annotations-api + + + org.apache.spark:spark-tags_${scala.binary.version}:test-jar + + + org.scalatest:scalatest_${scala.binary.version} + + + org.scalatestplus:scalacheck-1-18_${scala.binary.version} + + + org.scalatestplus:mockito-5-12_${scala.binary.version} + + + org.scalatestplus:selenium-4-21_${scala.binary.version} + + + org.junit.jupiter:junit-jupiter + + + com.github.sbt.junit:jupiter-interface + + + + + + org.codehaus.mojo @@ -142,8 +218,26 @@ org.spark-project.spark:unused org.apache.tomcat:annotations-api + com.google.guava:* + + + com.google.guava:guava + + com/google/thirdparty/** + + + + + + com.google.common + ${spark.shade.packageName}.connect.guava + + com.google.common.** + + + diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala index b15e8c28df744..090907a538c72 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala 
@@ -141,7 +141,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends sql.KeyValueGroupedDa statefulProcessor: StatefulProcessor[K, V, U], timeMode: TimeMode, outputMode: OutputMode): Dataset[U] = - unsupported() + transformWithStateHelper(statefulProcessor, timeMode, outputMode) /** @inheritdoc */ private[sql] def transformWithState[U: Encoder, S: Encoder]( @@ -149,20 +149,40 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends sql.KeyValueGroupedDa timeMode: TimeMode, outputMode: OutputMode, initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = - unsupported() + transformWithStateHelper(statefulProcessor, timeMode, outputMode, Some(initialState)) /** @inheritdoc */ override private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], eventTimeColumnName: String, - outputMode: OutputMode): Dataset[U] = unsupported() + outputMode: OutputMode): Dataset[U] = + transformWithStateHelper( + statefulProcessor, + TimeMode.EventTime(), + outputMode, + eventTimeColumnName = eventTimeColumnName) /** @inheritdoc */ override private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], eventTimeColumnName: String, outputMode: OutputMode, - initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = unsupported() + initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = + transformWithStateHelper( + statefulProcessor, + TimeMode.EventTime(), + outputMode, + Some(initialState), + eventTimeColumnName) + + // This only declares the interface and should not be called directly; the real + // implementation is in the inheriting subclass. + protected[sql] def transformWithStateHelper[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + timeMode: TimeMode, + outputMode: OutputMode, + initialState: Option[sql.KeyValueGroupedDataset[K, S]] = None, + eventTimeColumnName: String = ""): Dataset[U] = unsupported() // Overrides...
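A minimal client-side sketch of how these overloads are expected to be exercised, mirroring the batch-query test earlier in this patch; it assumes a Spark Connect endpoint at a placeholder URL and reuses BasicCountStatefulProcessor, InputRowForConnectTest, and OutputRowForConnectTest from that test suite:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.streaming.{OutputMode, TimeMode}

    // Placeholder connection string and input path; the processor and row case classes
    // come from the TransformWithState Connect test suite added in this patch.
    val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()
    import spark.implicits._

    spark.read
      .parquet("/tmp/tws-input")
      .as[InputRowForConnectTest]
      .groupByKey(_.key)
      .transformWithState[OutputRowForConnectTest](
        new BasicCountStatefulProcessor(), // serialized and dispatched via transformWithStateHelper
        TimeMode.None(),
        OutputMode.Update())
      .write
      .saveAsTable("tws_counts")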
/** @inheritdoc */ @@ -602,7 +622,6 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( } val initialStateImpl = if (initialState.isDefined) { - assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]) initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]] } else { null @@ -632,6 +651,53 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( } } + override protected[sql] def transformWithStateHelper[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + timeMode: TimeMode, + outputMode: OutputMode, + initialState: Option[sql.KeyValueGroupedDataset[K, S]] = None, + eventTimeColumnName: String = ""): Dataset[U] = { + val outputEncoder = agnosticEncoderFor[U] + val stateEncoder = agnosticEncoderFor[S] + val inputEncoders: Seq[AgnosticEncoder[_]] = Seq(kEncoder, stateEncoder, ivEncoder) + + // SparkUserDefinedFunction creates a udfPacket in which the input function is + // Java-serialized into bytes; we pass `statefulProcessor` as the function so that it can be + // serialized into bytes and deserialized back on the Connect server + val sparkUserDefinedFunc = + SparkUserDefinedFunction(statefulProcessor, inputEncoders, outputEncoder) + val funcProto = UdfToProtoUtils.toProto(sparkUserDefinedFunc) + + val initialStateImpl = if (initialState.isDefined) { + initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]] + } else { + null + } + + sparkSession.newDataset[U](outputEncoder) { builder => + val twsBuilder = builder.getGroupMapBuilder + val twsInfoBuilder = proto.TransformWithStateInfo.newBuilder() + if (!eventTimeColumnName.isEmpty) { + twsInfoBuilder.setEventTimeColumnName(eventTimeColumnName) + } + twsBuilder + .setInput(plan.getRoot) + .addAllGroupingExpressions(groupingExprs) + .setFunc(funcProto) + .setOutputMode(outputMode.toString) + .setTransformWithStateInfo( + twsInfoBuilder + // the time mode is passed as a string here and deterministically restored on the server + .setTimeMode(timeMode.toString) + .build()) + if (initialStateImpl != null) { + twsBuilder + .addAllInitialGroupingExpressions(initialStateImpl.groupingExprs) + .setInitialInput(initialStateImpl.plan.getRoot) + } + } + } + private def getUdf[U: Encoder](nf: AnyRef, outputEncoder: AgnosticEncoder[U])( inEncoders: AgnosticEncoder[_]*): proto.CommonInlineUserDefinedFunction = { val inputEncoders = kEncoder +: inEncoders // Apply keyAs changes by setting kEncoder diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index 4618c7e24d4ac..ceeece073da65 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -388,7 +388,7 @@ object ArrowDeserializers { } } - case (TransformingEncoder(_, encoder, provider), v) => + case (TransformingEncoder(_, encoder, provider, _), v) => new Deserializer[Any] { private[this] val codec = provider() private[this] val deserializer = deserializerFor(encoder, v, timeZoneId) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala index 584a318f039d8..a74fce3a8abaf 100644 ---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala @@ -485,7 +485,7 @@ object ArrowSerializer { o => getter.invoke(o) } - case (TransformingEncoder(_, encoder, provider), v) => + case (TransformingEncoder(_, encoder, provider, _), v) => new Serializer { private[this] val codec = provider().asInstanceOf[Codec[Any, Any]] private[this] val delegate: Serializer = serializerFor(encoder, v) diff --git a/sql/connect/server/pom.xml b/sql/connect/server/pom.xml index bb30b01f778de..58d7b17013fa5 100644 --- a/sql/connect/server/pom.xml +++ b/sql/connect/server/pom.xml @@ -193,17 +193,14 @@ io.grpc grpc-netty - ${io.grpc.version} io.grpc grpc-protobuf - ${io.grpc.version} io.grpc grpc-services - ${io.grpc.version} io.grpc @@ -213,7 +210,6 @@ io.netty netty-codec-http2 - ${netty.version} provided diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 24fc1275d4823..734eb394ca680 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -54,7 +54,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateStarAction} +import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TimeModes, TransformWithState, TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateEventTimeWatermarkColumn, UpdateStarAction} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -81,7 +81,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger} 
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.{ArrowUtils, CaseInsensitiveStringMap} import org.apache.spark.storage.CacheId @@ -684,7 +684,71 @@ class SparkConnectPlanner( rel.getGroupingExpressionsList, rel.getSortingExpressionsList) - if (rel.hasIsMapGroupsWithState) { + if (rel.hasTransformWithStateInfo) { + val hasInitialState = !rel.getInitialGroupingExpressionsList.isEmpty && rel.hasInitialInput + + val twsInfo = rel.getTransformWithStateInfo + val keyDeserializer = udf.inputDeserializer(ds.groupingAttributes) + val outputAttr = udf.outputObjAttr + + val timeMode = TimeModes(twsInfo.getTimeMode) + val outputMode = InternalOutputModes(rel.getOutputMode) + + val twsNode = if (hasInitialState) { + val statefulProcessor = unpackedUdf.function + .asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]] + val initDs = UntypedKeyValueGroupedDataset( + rel.getInitialInput, + rel.getInitialGroupingExpressionsList, + rel.getSortingExpressionsList) + new TransformWithState( + keyDeserializer, + ds.valueDeserializer, + ds.groupingAttributes, + ds.dataAttributes, + statefulProcessor, + timeMode, + outputMode, + udf.inEnc.asInstanceOf[ExpressionEncoder[Any]], + outputAttr, + ds.analyzed, + hasInitialState, + initDs.groupingAttributes, + initDs.dataAttributes, + initDs.valueDeserializer, + initDs.analyzed) + } else { + val statefulProcessor = + unpackedUdf.function.asInstanceOf[StatefulProcessor[Any, Any, Any]] + new TransformWithState( + keyDeserializer, + ds.valueDeserializer, + ds.groupingAttributes, + ds.dataAttributes, + statefulProcessor, + timeMode, + outputMode, + udf.inEnc.asInstanceOf[ExpressionEncoder[Any]], + outputAttr, + ds.analyzed, + hasInitialState, + ds.groupingAttributes, + ds.dataAttributes, + keyDeserializer, + LocalRelation(ds.vEncoder.schema)) + } + val serializedPlan = SerializeFromObject(udf.outputNamedExpression, twsNode) + + if (twsInfo.hasEventTimeColumnName) { + val eventTimeWrappedPlan = UpdateEventTimeWatermarkColumn( + UnresolvedAttribute(twsInfo.getEventTimeColumnName), + None, + serializedPlan) + eventTimeWrappedPlan + } else { + serializedPlan + } + } else if (rel.hasIsMapGroupsWithState) { val hasInitialState = !rel.getInitialGroupingExpressionsList.isEmpty && rel.hasInitialInput val initialDs = if (hasInitialState) { UntypedKeyValueGroupedDataset( diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index a156be189c650..c158de48574f3 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -249,7 +249,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { // Visible for testing. private[connect] def periodicMaintenance(timeoutNs: Long): Unit = { // Find any detached executions that expired and should be removed. 
- logInfo("Started periodic run of SparkConnectExecutionManager maintenance.") + logDebug("Started periodic run of SparkConnectExecutionManager maintenance.") val nowNs = System.nanoTime() executions.forEach((_, executeHolder) => { @@ -266,7 +266,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { } }) - logInfo("Finished periodic run of SparkConnectExecutionManager maintenance.") + logDebug("Finished periodic run of SparkConnectExecutionManager maintenance.") } // For testing. diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala index 8581bb7b98f05..d8b3e21e7ad61 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala @@ -234,7 +234,7 @@ class SparkConnectSessionManager extends Logging { defaultInactiveTimeoutMs: Long, ignoreCustomTimeout: Boolean): Unit = { // Find any sessions that expired and should be removed. - logInfo("Started periodic run of SparkConnectSessionManager maintenance.") + logDebug("Started periodic run of SparkConnectSessionManager maintenance.") def shouldExpire(info: SessionHolderInfo, nowMs: Long): Boolean = { val timeoutMs = if (info.customInactiveTimeoutMs.isDefined && !ignoreCustomTimeout) { @@ -262,7 +262,7 @@ class SparkConnectSessionManager extends Logging { } }) - logInfo("Finished periodic run of SparkConnectSessionManager maintenance.") + logDebug("Finished periodic run of SparkConnectSessionManager maintenance.") } private def newIsolatedSession(): SparkSession = { diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerSessionPage.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerSessionPage.scala index c35eb58edfbe7..1f335c9ce0051 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerSessionPage.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerSessionPage.scala @@ -43,7 +43,7 @@ private[ui] class SparkConnectServerSessionPage(parent: SparkConnectServerTab) store .getSession(sessionId) .map { sessionStat => - generateBasicStats() ++ + generateBasicStats(sessionId) ++
++

User @@ -64,9 +64,12 @@ private[ui] class SparkConnectServerSessionPage(parent: SparkConnectServerTab) } /** Generate basic stats of the Spark Connect Server */ - private def generateBasicStats(): Seq[Node] = { + private def generateBasicStats(sessionId: String): Seq[Node] = { val timeSinceStart = System.currentTimeMillis() - startTime.getTime
+        Session ID: {sessionId}
         Started at: {formatDate(startTime)}
  • diff --git a/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-jdk21-results.txt b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-jdk21-results.txt index bd2311634a5bc..1e3fc590644ad 100644 --- a/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-jdk21-results.txt +++ b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-jdk21-results.txt @@ -2,143 +2,143 @@ put rows ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (10000 rows to overwrite - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 9 1 1.2 822.1 1.0X -RocksDB (trackTotalNumberOfRows: true) 45 47 2 0.2 4455.4 0.2X -RocksDB (trackTotalNumberOfRows: false) 17 17 1 0.6 1655.9 0.5X +In-memory 8 9 1 1.2 815.6 1.0X +RocksDB (trackTotalNumberOfRows: true) 46 47 2 0.2 4559.1 0.2X +RocksDB (trackTotalNumberOfRows: false) 17 18 1 0.6 1678.7 0.5X -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (5000 rows to overwrite - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 9 1 1.2 805.6 1.0X -RocksDB (trackTotalNumberOfRows: true) 46 47 1 0.2 4561.6 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1637.5 0.5X +In-memory 8 9 1 1.3 798.1 1.0X +RocksDB (trackTotalNumberOfRows: true) 47 48 2 0.2 4659.8 0.2X +RocksDB (trackTotalNumberOfRows: false) 17 17 1 0.6 1663.4 0.5X -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (1000 rows to overwrite - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 8 1 1.3 782.0 1.0X -RocksDB (trackTotalNumberOfRows: true) 45 47 1 0.2 4537.4 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1633.0 0.5X +In-memory 8 9 1 1.3 794.9 1.0X +RocksDB (trackTotalNumberOfRows: true) 46 48 1 0.2 4625.7 0.2X +RocksDB (trackTotalNumberOfRows: false) 17 17 1 0.6 1660.7 0.5X -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (0 rows to overwrite - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 8 1 1.3 783.3 1.0X -RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4484.9 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1641.4 0.5X +In-memory 8 8 1 1.3 788.6 1.0X +RocksDB (trackTotalNumberOfRows: true) 46 47 1 0.2 4557.0 0.2X +RocksDB (trackTotalNumberOfRows: false) 17 17 1 0.6 1650.3 0.5X ================================================================================================ merge 
rows ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (10000 rows to overwrite - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------------------------------- -RocksDB (trackTotalNumberOfRows: true) 565 579 7 0.0 56471.0 1.0X -RocksDB (trackTotalNumberOfRows: false) 182 188 3 0.1 18161.0 3.1X +RocksDB (trackTotalNumberOfRows: true) 574 585 6 0.0 57387.8 1.0X +RocksDB (trackTotalNumberOfRows: false) 181 186 3 0.1 18065.2 3.2X -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (5000 rows to overwrite - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------ -RocksDB (trackTotalNumberOfRows: true) 500 512 5 0.0 50023.6 1.0X -RocksDB (trackTotalNumberOfRows: false) 183 188 3 0.1 18312.9 2.7X +RocksDB (trackTotalNumberOfRows: true) 504 515 5 0.0 50382.4 1.0X +RocksDB (trackTotalNumberOfRows: false) 179 185 3 0.1 17882.2 2.8X -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (1000 rows to overwrite - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------ -RocksDB (trackTotalNumberOfRows: true) 436 447 5 0.0 43613.9 1.0X -RocksDB (trackTotalNumberOfRows: false) 181 186 3 0.1 18065.5 2.4X +RocksDB (trackTotalNumberOfRows: true) 442 455 6 0.0 44235.2 1.0X +RocksDB (trackTotalNumberOfRows: false) 180 185 3 0.1 17971.5 2.5X -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (0 rows to overwrite - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------------------------- -RocksDB (trackTotalNumberOfRows: true) 413 425 5 0.0 41349.9 1.0X -RocksDB (trackTotalNumberOfRows: false) 181 187 4 0.1 18075.6 2.3X +RocksDB (trackTotalNumberOfRows: true) 424 436 5 0.0 42391.9 1.0X +RocksDB (trackTotalNumberOfRows: false) 179 185 4 0.1 17923.5 2.4X ================================================================================================ delete rows ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(10000 rows are non-existing - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) 
Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 0 0 0 26.8 37.3 1.0X -RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4396.1 0.0X -RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1522.7 0.0X +In-memory 0 1 0 27.1 36.9 1.0X +RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4470.0 0.0X +RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1583.0 0.0X -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(5000 rows are non-existing - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 7 7 0 1.5 666.2 1.0X -RocksDB (trackTotalNumberOfRows: true) 44 46 1 0.2 4392.1 0.2X -RocksDB (trackTotalNumberOfRows: false) 15 16 0 0.7 1511.0 0.4X +In-memory 7 7 0 1.5 651.4 1.0X +RocksDB (trackTotalNumberOfRows: true) 46 47 1 0.2 4580.3 0.1X +RocksDB (trackTotalNumberOfRows: false) 16 17 0 0.6 1582.7 0.4X -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(1000 rows are non-existing - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 7 8 1 1.4 714.3 1.0X -RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4362.1 0.2X -RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1520.6 0.5X +In-memory 7 8 0 1.4 713.7 1.0X +RocksDB (trackTotalNumberOfRows: true) 45 47 1 0.2 4538.6 0.2X +RocksDB (trackTotalNumberOfRows: false) 16 16 0 0.6 1579.3 0.5X -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(0 rows are non-existing - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 7 8 0 1.4 725.8 1.0X -RocksDB (trackTotalNumberOfRows: true) 43 45 1 0.2 4310.1 0.2X -RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1528.4 0.5X +In-memory 7 8 0 1.4 716.9 1.0X +RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4459.8 0.2X +RocksDB (trackTotalNumberOfRows: false) 16 16 1 0.6 1580.7 0.5X ================================================================================================ evict rows ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor evicting 10000 rows (maxTimestampToEvictInMillis: 9999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 7 8 0 
1.4 715.6 1.0X -RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4386.8 0.2X -RocksDB (trackTotalNumberOfRows: false) 17 17 0 0.6 1686.7 0.4X +In-memory 7 7 0 1.5 689.5 1.0X +RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4424.0 0.2X +RocksDB (trackTotalNumberOfRows: false) 18 18 0 0.6 1784.2 0.4X -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor evicting 5000 rows (maxTimestampToEvictInMillis: 4999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------ -In-memory 7 7 0 1.5 667.7 1.0X -RocksDB (trackTotalNumberOfRows: true) 23 24 1 0.4 2292.5 0.3X -RocksDB (trackTotalNumberOfRows: false) 10 10 0 1.0 994.3 0.7X +In-memory 6 7 0 1.5 650.0 1.0X +RocksDB (trackTotalNumberOfRows: true) 23 24 1 0.4 2347.8 0.3X +RocksDB (trackTotalNumberOfRows: false) 10 11 0 1.0 1037.1 0.6X -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor evicting 1000 rows (maxTimestampToEvictInMillis: 999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 6 6 0 1.7 603.8 1.0X -RocksDB (trackTotalNumberOfRows: true) 7 8 0 1.3 749.5 0.8X -RocksDB (trackTotalNumberOfRows: false) 5 5 0 2.1 482.2 1.3X +In-memory 6 6 0 1.7 585.4 1.0X +RocksDB (trackTotalNumberOfRows: true) 8 8 0 1.3 766.5 0.8X +RocksDB (trackTotalNumberOfRows: false) 5 5 0 2.0 503.2 1.2X -OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor evicting 0 rows (maxTimestampToEvictInMillis: -1) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 0 0 0 23.7 42.1 1.0X -RocksDB (trackTotalNumberOfRows: true) 3 4 0 2.9 345.1 0.1X -RocksDB (trackTotalNumberOfRows: false) 3 4 0 2.9 344.6 0.1X +In-memory 0 0 0 25.0 40.1 1.0X +RocksDB (trackTotalNumberOfRows: true) 4 4 0 2.8 359.1 0.1X +RocksDB (trackTotalNumberOfRows: false) 4 4 0 2.8 359.9 0.1X diff --git a/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt index 4a7c21d01c80e..a3688c16f100a 100644 --- a/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt +++ b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt @@ -2,143 +2,143 @@ put rows ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (10000 rows to overwrite - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 10 1 1.2 842.4 1.0X -RocksDB (trackTotalNumberOfRows: true) 
45 47 2 0.2 4529.0 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1635.3 0.5X +In-memory 8 9 1 1.2 816.3 1.0X +RocksDB (trackTotalNumberOfRows: true) 45 47 1 0.2 4514.1 0.2X +RocksDB (trackTotalNumberOfRows: false) 17 18 1 0.6 1682.7 0.5X -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (5000 rows to overwrite - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 10 1 1.2 831.7 1.0X -RocksDB (trackTotalNumberOfRows: true) 47 48 1 0.2 4662.3 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1625.5 0.5X +In-memory 8 10 1 1.2 811.7 1.0X +RocksDB (trackTotalNumberOfRows: true) 47 49 1 0.2 4694.9 0.2X +RocksDB (trackTotalNumberOfRows: false) 17 18 1 0.6 1680.2 0.5X -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (1000 rows to overwrite - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 9 1 1.2 802.0 1.0X -RocksDB (trackTotalNumberOfRows: true) 46 48 1 0.2 4634.4 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1616.5 0.5X +In-memory 8 9 1 1.3 786.5 1.0X +RocksDB (trackTotalNumberOfRows: true) 47 48 1 0.2 4679.7 0.2X +RocksDB (trackTotalNumberOfRows: false) 17 18 1 0.6 1650.0 0.5X -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (0 rows to overwrite - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 9 1 1.2 828.2 1.0X -RocksDB (trackTotalNumberOfRows: true) 46 47 1 0.2 4593.4 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1596.2 0.5X +In-memory 8 8 1 1.3 778.0 1.0X +RocksDB (trackTotalNumberOfRows: true) 46 48 1 0.2 4629.4 0.2X +RocksDB (trackTotalNumberOfRows: false) 17 17 1 0.6 1664.9 0.5X ================================================================================================ merge rows ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (10000 rows to overwrite - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------------------------------- -RocksDB (trackTotalNumberOfRows: true) 574 593 7 0.0 57382.2 1.0X -RocksDB (trackTotalNumberOfRows: false) 186 191 3 0.1 18572.6 3.1X +RocksDB (trackTotalNumberOfRows: true) 570 585 6 0.0 56996.2 1.0X +RocksDB (trackTotalNumberOfRows: false) 184 190 3 0.1 18411.4 3.1X -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD 
EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (5000 rows to overwrite - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------ -RocksDB (trackTotalNumberOfRows: true) 502 513 5 0.0 50183.7 1.0X -RocksDB (trackTotalNumberOfRows: false) 185 191 3 0.1 18542.0 2.7X +RocksDB (trackTotalNumberOfRows: true) 493 505 5 0.0 49327.2 1.0X +RocksDB (trackTotalNumberOfRows: false) 181 188 3 0.1 18140.8 2.7X -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (1000 rows to overwrite - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------ -RocksDB (trackTotalNumberOfRows: true) 439 453 6 0.0 43896.3 1.0X -RocksDB (trackTotalNumberOfRows: false) 184 190 3 0.1 18384.9 2.4X +RocksDB (trackTotalNumberOfRows: true) 435 448 5 0.0 43484.3 1.0X +RocksDB (trackTotalNumberOfRows: false) 183 188 3 0.1 18289.1 2.4X -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (0 rows to overwrite - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------------------------- -RocksDB (trackTotalNumberOfRows: true) 421 433 5 0.0 42057.9 1.0X -RocksDB (trackTotalNumberOfRows: false) 184 192 3 0.1 18421.4 2.3X +RocksDB (trackTotalNumberOfRows: true) 416 432 5 0.0 41606.2 1.0X +RocksDB (trackTotalNumberOfRows: false) 183 189 3 0.1 18282.2 2.3X ================================================================================================ delete rows ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(10000 rows are non-existing - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 0 1 0 26.3 38.0 1.0X -RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4510.4 0.0X -RocksDB (trackTotalNumberOfRows: false) 16 17 0 0.6 1585.4 0.0X +In-memory 0 1 0 26.6 37.7 1.0X +RocksDB (trackTotalNumberOfRows: true) 45 47 1 0.2 4514.1 0.0X +RocksDB (trackTotalNumberOfRows: false) 16 17 0 0.6 1587.8 0.0X -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(5000 rows are non-existing - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------------------- 
-In-memory 7 7 0 1.5 673.9 1.0X -RocksDB (trackTotalNumberOfRows: true) 46 47 1 0.2 4566.1 0.1X -RocksDB (trackTotalNumberOfRows: false) 16 16 0 0.6 1572.0 0.4X +In-memory 6 7 1 1.6 644.9 1.0X +RocksDB (trackTotalNumberOfRows: true) 45 47 1 0.2 4524.6 0.1X +RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1579.1 0.4X -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(1000 rows are non-existing - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 7 8 1 1.4 725.8 1.0X -RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4481.0 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 16 0 0.6 1582.0 0.5X +In-memory 7 8 1 1.4 698.2 1.0X +RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4481.1 0.2X +RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1585.3 0.4X -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(0 rows are non-existing - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 7 8 1 1.4 736.3 1.0X -RocksDB (trackTotalNumberOfRows: true) 44 46 1 0.2 4449.2 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 16 1 0.6 1570.2 0.5X +In-memory 7 8 1 1.4 707.0 1.0X +RocksDB (trackTotalNumberOfRows: true) 43 45 1 0.2 4326.6 0.2X +RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1560.6 0.5X ================================================================================================ evict rows ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor evicting 10000 rows (maxTimestampToEvictInMillis: 9999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 7 8 0 1.4 719.2 1.0X -RocksDB (trackTotalNumberOfRows: true) 43 44 1 0.2 4313.7 0.2X -RocksDB (trackTotalNumberOfRows: false) 17 17 1 0.6 1652.9 0.4X +In-memory 7 7 0 1.4 693.7 1.0X +RocksDB (trackTotalNumberOfRows: true) 43 44 1 0.2 4285.3 0.2X +RocksDB (trackTotalNumberOfRows: false) 17 18 0 0.6 1726.3 0.4X -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor evicting 5000 rows (maxTimestampToEvictInMillis: 4999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------ -In-memory 7 7 0 1.5 670.7 1.0X -RocksDB (trackTotalNumberOfRows: true) 23 24 1 0.4 2332.2 0.3X -RocksDB (trackTotalNumberOfRows: false) 10 11 0 1.0 1026.8 0.7X +In-memory 6 7 0 1.5 646.3 1.0X +RocksDB 
(trackTotalNumberOfRows: true) 24 24 0 0.4 2351.2 0.3X +RocksDB (trackTotalNumberOfRows: false) 11 11 0 0.9 1062.9 0.6X -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor evicting 1000 rows (maxTimestampToEvictInMillis: 999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 6 7 0 1.6 610.3 1.0X -RocksDB (trackTotalNumberOfRows: true) 8 8 0 1.3 767.9 0.8X -RocksDB (trackTotalNumberOfRows: false) 5 5 0 2.0 507.6 1.2X +In-memory 6 6 0 1.7 587.7 1.0X +RocksDB (trackTotalNumberOfRows: true) 8 8 0 1.3 784.7 0.7X +RocksDB (trackTotalNumberOfRows: false) 5 6 0 1.9 529.1 1.1X -OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure AMD EPYC 7763 64-Core Processor evicting 0 rows (maxTimestampToEvictInMillis: -1) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 0 0 0 23.1 43.3 1.0X -RocksDB (trackTotalNumberOfRows: true) 4 4 0 2.7 370.8 0.1X -RocksDB (trackTotalNumberOfRows: false) 4 4 0 2.7 371.8 0.1X +In-memory 0 0 0 23.2 43.2 1.0X +RocksDB (trackTotalNumberOfRows: true) 4 4 0 2.6 387.5 0.1X +RocksDB (trackTotalNumberOfRows: false) 4 4 0 2.6 389.4 0.1X diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala index 9d1ca3ce5fa0a..9652add9cac38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala @@ -69,8 +69,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.util.{NextIterator, Utils} import org.apache.spark.util.ArrayImplicits._ -import org.apache.spark.util.Utils private[sql] object Dataset { val curId = new java.util.concurrent.atomic.AtomicLong() @@ -1677,21 +1677,19 @@ class Dataset[T] private[sql]( val gen = new JacksonGenerator(rowSchema, writer, new JSONOptions(Map.empty[String, String], sessionLocalTimeZone)) - new Iterator[String] { + new NextIterator[String] { private val toRow = exprEnc.createSerializer() - override def hasNext: Boolean = iter.hasNext - override def next(): String = { + override def close(): Unit = { gen.close() } + override def getNext(): String = { + if (!iter.hasNext) { + finished = true + return "" + } + writer.reset() gen.write(toRow(iter.next())) gen.flush() - val json = writer.toString - if (hasNext) { - writer.reset() - } else { - gen.close() - } - - json + writer.toString } } } (Encoders.STRING) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala index 0607a8593fbb8..1aba10e40c364 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala @@ 
-239,23 +239,17 @@ case class DescribeRelationJsonCommand( private def describeClusteringInfoJson( table: CatalogTable, jsonMap: mutable.LinkedHashMap[String, JValue]): Unit = { table.clusterBySpec.foreach { clusterBySpec => - val clusteringColumnsJson: JValue = JArray( - clusterBySpec.columnNames.map { fieldNames => - val nestedFieldOpt = table.schema.findNestedField(fieldNames.fieldNames.toIndexedSeq) - assert(nestedFieldOpt.isDefined, - "The clustering column " + - s"${fieldNames.fieldNames.map(quoteIfNeeded).mkString(".")} " + - s"was not found in the table schema ${table.schema.catalogString}." - ) - val (path, field) = nestedFieldOpt.get - JObject( - "name" -> JString((path :+ field.name).map(quoteIfNeeded).mkString(".")), - "type" -> jsonType(field.dataType), - "comment" -> field.getComment().map(JString).getOrElse(JNull) - ) - }.toList - ) - addKeyValueToMap("clustering_information", clusteringColumnsJson, jsonMap) + val clusteringColumnsJson = JArray(clusterBySpec.columnNames.map { fieldNames => + val nestedFieldOpt = table.schema.findNestedField(fieldNames.fieldNames.toIndexedSeq) + assert(nestedFieldOpt.isDefined, + "The clustering column " + + s"${fieldNames.fieldNames.map(quoteIfNeeded).mkString(".")} " + + s"was not found in the table schema ${table.schema.catalogString}." + ) + JString(fieldNames.fieldNames.map(quoteIfNeeded).mkString(".")) + }.toList) + + addKeyValueToMap("clustering_columns", clusteringColumnsJson, jsonMap) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala index e9cc23c6a5bab..5960cf8c38ced 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala @@ -289,13 +289,13 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { relation: LogicalRelation, hadoopFsRelation: HadoopFsRelation): LogicalPlan = { val variants = new VariantInRelation - val defaultValues = ResolveDefaultColumns.existenceDefaultValues(hadoopFsRelation.schema) - // I'm not aware of any case that an attribute `relation.output` can have a different data type - // than the corresponding field in `hadoopFsRelation.schema`. Other code seems to prefer using - // the data type in `hadoopFsRelation.schema`, let's also stick to it. - val schemaWithAttributes = hadoopFsRelation.schema.fields.zip(relation.output) - for (((f, attr), defaultValue) <- schemaWithAttributes.zip(defaultValues)) { - variants.addVariantFields(attr.exprId, f.dataType, defaultValue, Nil) + + val schemaAttributes = relation.resolve(hadoopFsRelation.dataSchema, + hadoopFsRelation.sparkSession.sessionState.analyzer.resolver) + val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType( + schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) + for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) { + variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil) } if (variants.mapping.isEmpty) return originalPlan @@ -304,24 +304,28 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { // `collectRequestedFields` may have removed all variant columns. 
if (variants.mapping.forall(_._2.isEmpty)) return originalPlan - val (newFields, newOutput) = schemaWithAttributes.map { - case (f, attr) => - if (variants.mapping.get(attr.exprId).exists(_.nonEmpty)) { - val newType = variants.rewriteType(attr.exprId, f.dataType, Nil) - val newAttr = AttributeReference(f.name, newType, f.nullable, f.metadata)() - (f.copy(dataType = newType), newAttr) - } else { - (f, attr) - } - }.unzip + val attributeMap = schemaAttributes.map { a => + if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) { + val newType = variants.rewriteType(a.exprId, a.dataType, Nil) + val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)( + qualifier = a.qualifier) + (a.exprId, newAttr) + } else { + // `relation.resolve` actually returns `Seq[AttributeReference]`, although the return type + // is `Seq[Attribute]`. + (a.exprId, a.asInstanceOf[AttributeReference]) + } + }.toMap + val newFields = schemaAttributes.map { a => + val dataType = attributeMap(a.exprId).dataType + StructField(a.name, dataType, a.nullable, a.metadata) + } + val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a)) val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = StructType(newFields))( hadoopFsRelation.sparkSession) val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq) - val attributeMap = relation.output.zip(newOutput).map { - case (oldAttr, newAttr) => oldAttr.exprId -> newAttr - }.toMap val withFilter = if (filters.nonEmpty) { Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), newRelation) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 6bd49e75af241..2fefd8f70cd5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.execution.python.BatchIterator import org.apache.spark.sql.execution.r.ArrowRRunner import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.internal.SQLConf @@ -218,13 +219,17 @@ case class MapPartitionsInRWithArrowExec( child: SparkPlan) extends UnaryExecNode { override def producedAttributes: AttributeSet = AttributeSet(output) + private val batchSize = conf.arrowMaxRecordsPerBatch + override def outputPartitioning: Partitioning = child.outputPartitioning override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { inputIter => val outputTypes = schema.map(_.dataType) - val batchIter = Iterator(inputIter) + // DO NOT use iter.grouped(). See BatchIterator. 
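+ // (Iterator.grouped(n) would materialize every batch as a Seq before handing it to the runner, + // whereas BatchIterator exposes each batch as a lazy sub-iterator over the input rows.)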
+ val batchIter = + if (batchSize > 0) new BatchIterator(inputIter, batchSize) else Iterator(inputIter) val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema, SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_DAPPLY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala index 3b7b2c56744a8..88b63f3b2dd09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala @@ -69,7 +69,7 @@ class MapInBatchEvaluatorFactory( pythonRunnerConf, pythonMetrics, jobArtifactUUID, - None) + None) with BatchedPythonArrowInput val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context) val unsafeProj = UnsafeProjection.create(output, output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index adbfa341f3d5d..05fd571e0265b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -145,7 +145,10 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput { self: BasePythonRunner[Iterator[InternalRow], _] => - private val arrowMaxRecordsPerBatch = SQLConf.get.arrowMaxRecordsPerBatch + private val arrowMaxRecordsPerBatch = { + val v = SQLConf.get.arrowMaxRecordsPerBatch + if (v > 0) v else Int.MaxValue + } private val maxBytesPerBatch = SQLConf.get.arrowMaxBytesPerBatch diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 820322d1e0ee1..1c04398a4ed79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -671,16 +671,15 @@ class RocksDB( if (useColumnFamilies) { changelogReader.foreach { case (recordType, key, value) => - val (keyWithoutPrefix, cfName) = decodeStateRowWithPrefix(key) recordType match { case RecordType.PUT_RECORD => - put(keyWithoutPrefix, value, cfName) + put(key, value, includesPrefix = true) case RecordType.DELETE_RECORD => - remove(keyWithoutPrefix, cfName) + remove(key, includesPrefix = true) case RecordType.MERGE_RECORD => - merge(keyWithoutPrefix, value, cfName) + merge(key, value, includesPrefix = true) } } } else { @@ -801,8 +800,9 @@ class RocksDB( def put( key: Array[Byte], value: Array[Byte], - cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { - val keyWithPrefix = if (useColumnFamilies) { + cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME, + includesPrefix: Boolean = false): Unit = { + val keyWithPrefix = if (useColumnFamilies && !includesPrefix) { encodeStateRowWithPrefix(key, cfName) } else { key @@ -827,8 +827,9 @@ class RocksDB( def merge( key: Array[Byte], value: Array[Byte], - cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { - val keyWithPrefix = if (useColumnFamilies) { + cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME, + includesPrefix: Boolean = false): Unit = 
{ + val keyWithPrefix = if (useColumnFamilies && !includesPrefix) { encodeStateRowWithPrefix(key, cfName) } else { key @@ -843,8 +844,11 @@ class RocksDB( * Remove the key if present. * @note This update is not committed to disk until commit() is called. */ - def remove(key: Array[Byte], cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { - val keyWithPrefix = if (useColumnFamilies) { + def remove( + key: Array[Byte], + cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME, + includesPrefix: Boolean = false): Unit = { + val keyWithPrefix = if (useColumnFamilies && !includesPrefix) { encodeStateRowWithPrefix(key, cfName) } else { key diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index a5171289297d1..25322f4db189f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -971,6 +971,14 @@ class ForStatementExec( createSetVarExec(colName, expr) ) } + val varScopeLabel = variableName.map(varName => + if (session.sessionState.conf.caseSensitiveAnalysis) { + varName + } else { + varName.toLowerCase(Locale.ROOT) + } + ).orElse(Some(UUID.randomUUID().toString.toLowerCase(Locale.ROOT))) + bodyWithVariables = new CompoundBodyExec( // NoOpStatementExec appended to end of body to prevent // dropping variables before last statement is executed. @@ -979,7 +987,7 @@ class ForStatementExec( statements = variableInitStatements ++ statements :+ new NoOpStatementExec, // We generate label name if FOR variable is not specified, similar to how // compound bodies have generated label names if label is not specified. 
- label = variableName.orElse(Some(UUID.randomUUID().toString.toLowerCase(Locale.ROOT))), + label = varScopeLabel, isScope = true, context = context, triggerToExceptionHandlerMap = TriggerToExceptionHandlerMap.createEmptyMap() diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/describe.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/describe.sql.out index c13dd53023876..2f7237663b646 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/describe.sql.out @@ -393,6 +393,27 @@ DESC FORMATTED e DescribeTableCommand `spark_catalog`.`default`.`e`, true, [col_name#x, data_type#x, comment#x] +-- !query +CREATE TABLE f USING json PARTITIONED BY (B, C) AS SELECT 'APACHE' A, CAST('SPARK' AS BINARY) B, TIMESTAMP'2018-11-17 13:33:33' C +-- !query analysis +CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`f`, ErrorIfExists, [A, B, C] + +- Project [APACHE AS A#x, cast(SPARK as binary) AS B#x, 2018-11-17 13:33:33 AS C#x] + +- OneRowRelation + + +-- !query +DESC FORMATTED f PARTITION (B='SPARK', C=TIMESTAMP'2018-11-17 13:33:33') +-- !query analysis +DescribeTableCommand `spark_catalog`.`default`.`f`, [B=SPARK, C=2018-11-17 13:33:33], true, [col_name#x, data_type#x, comment#x] + + +-- !query +DESC TABLE EXTENDED f PARTITION (B='SPARK', C=TIMESTAMP'2018-11-17 13:33:33') AS JSON +-- !query analysis +DescribeRelationJsonCommand [B=SPARK, C=2018-11-17 13:33:33], true, [json_metadata#x] ++- ResolvedTable V2SessionCatalog(spark_catalog), default.f, V1Table(default.f), [A#x, B#x, C#x] + + -- !query DROP VIEW temp_v -- !query analysis @@ -430,3 +451,10 @@ DROP TABLE e -- !query analysis DropTable false, false +- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.e + + +-- !query +DROP TABLE f +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.f diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql index dd7f7c40ed524..dbe5bc840bce6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -122,6 +122,12 @@ DESC TABLE EXTENDED e; DESC FORMATTED e; +CREATE TABLE f USING json PARTITIONED BY (B, C) AS SELECT 'APACHE' A, CAST('SPARK' AS BINARY) B, TIMESTAMP'2018-11-17 13:33:33' C; + +DESC FORMATTED f PARTITION (B='SPARK', C=TIMESTAMP'2018-11-17 13:33:33'); + +DESC TABLE EXTENDED f PARTITION (B='SPARK', C=TIMESTAMP'2018-11-17 13:33:33') AS JSON; + -- DROP TEST TABLES/VIEWS DROP VIEW temp_v; @@ -135,3 +141,5 @@ DROP TABLE t; DROP TABLE d; DROP TABLE e; + +DROP TABLE f; diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index da53ece198daa..d5a3dd50dc7ec 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -890,6 +890,48 @@ a string CONCAT('a\n b\n ', 'c\n d') b int 42 +-- !query +CREATE TABLE f USING json PARTITIONED BY (B, C) AS SELECT 'APACHE' A, CAST('SPARK' AS BINARY) B, TIMESTAMP'2018-11-17 13:33:33' C +-- !query schema +struct<> +-- !query output + + + +-- !query +DESC FORMATTED f PARTITION (B='SPARK', C=TIMESTAMP'2018-11-17 13:33:33') +-- !query schema +struct +-- !query output +A string +B binary +C timestamp +# Partition Information +# col_name data_type comment +B binary +C 
timestamp + +# Detailed Partition Information +Database default +Table f +Partition Values [B=SPARK, C=2018-11-17 13:33:33] +Location [not included in comparison]/{warehouse_dir}/f/B=SPARK/C=2018-11-17 13%3A33%3A33 +Partition Parameters [numFiles=1, totalSize=15, transient_lastDdlTime=[not included in comparison]] +Created Time [not included in comparison] +Last Access [not included in comparison] + +# Storage Information +Location [not included in comparison]/{warehouse_dir}/f + + +-- !query +DESC TABLE EXTENDED f PARTITION (B='SPARK', C=TIMESTAMP'2018-11-17 13:33:33') AS JSON +-- !query schema +struct +-- !query output +{"table_name":"f","catalog_name":"spark_catalog","namespace":["default"],"schema_name":"default","columns":[{"name":"A","type":{"name":"string"},"nullable":true},{"name":"B","type":{"name":"binary"},"nullable":true},{"name":"C","type":{"name":"timestamp_ltz"},"nullable":true}],"partition_values":{"B":"SPARK","C":"2018-11-17 13:33:33"},"location":"file:[not included in comparison]/{warehouse_dir}/f/B=SPARK/C=2018-11-17 13%3A33%3A33","partition_parameters":{"numFiles":"1","totalSize":"15","transient_lastDdlTime [not included in comparison]":"None"},"created_time [not included in comparison]":"None","last_access [not included in comparison]":"None","created_by [not included in comparison]":"None","type":"MANAGED","provider":"json","partition_provider":"Catalog","partition_columns":["B","C"]} + + -- !query DROP VIEW temp_v -- !query schema @@ -936,3 +978,11 @@ DROP TABLE e struct<> -- !query output + + +-- !query +DROP TABLE f +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out index 933982f5cff6a..eaae069a2a4bd 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out @@ -394,7 +394,9 @@ org.apache.spark.SparkDateTimeException "queryContext" : [ { "objectType" : "", "objectName" : "", - "fragment" : "" + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "to_timestamp(1)" } ] } diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out index ea17c7f5289fc..3d42e866ea17f 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out @@ -344,7 +344,9 @@ org.apache.spark.SparkDateTimeException "queryContext" : [ { "objectType" : "", "objectName" : "", - "fragment" : "" + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "to_timestamp(a)" } ] } diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-union.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-union.sql.out index 7f2931c3ade0a..c61a6f841c880 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-union.sql.out @@ -51,9 +51,9 @@ org.apache.spark.SparkNumberFormatException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 1, - "stopIndex" : 243, - "fragment" : "SELECT udf(c1) as c1, udf(c2) as c2\nFROM (SELECT udf(c1) as c1, udf(c2) as c2 FROM t1 WHERE c2 = 'a'\n UNION ALL\n SELECT 
udf(c1) as c1, udf(c2) as c2 FROM t2\n UNION ALL\n SELECT udf(c1) as c1, udf(c2) as c2 FROM t2)" + "startIndex" : 45, + "stopIndex" : 172, + "fragment" : "SELECT udf(c1) as c1, udf(c2) as c2 FROM t1 WHERE c2 = 'a'\n UNION ALL\n SELECT udf(c1) as c1, udf(c2) as c2 FROM t2" } ] } diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out index d7db2163f8b87..84e21a535f35d 100644 --- a/sql/core/src/test/resources/sql-tests/results/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out @@ -51,9 +51,9 @@ org.apache.spark.SparkNumberFormatException "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 1, - "stopIndex" : 133, - "fragment" : "SELECT *\nFROM (SELECT * FROM t1 where c1 = 1\n UNION ALL\n SELECT * FROM t2\n UNION ALL\n SELECT * FROM t2)" + "startIndex" : 18, + "stopIndex" : 89, + "fragment" : "SELECT * FROM t1 where c1 = 1\n UNION ALL\n SELECT * FROM t2" } ] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 7844acc8bdcfd..81713c777bc22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -36,9 +36,11 @@ import org.apache.spark.TestUtils.withListener import org.apache.spark.internal.config.MAX_RESULT_SIZE import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample} -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoders, ExpressionEncoder, OuterScopes} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.BoxedIntEncoder -import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, GenericRowWithSchema} +import org.apache.spark.sql.catalyst.DeserializerBuildHelper.createDeserializerForString +import org.apache.spark.sql.catalyst.SerializerBuildHelper.createSerializerForString +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoders, AgnosticExpressionPathEncoder, ExpressionEncoder, OuterScopes} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedIntEncoder, ProductEncoder} +import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, Expression, GenericRowWithSchema} import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.trees.DataFrameQueryContext import org.apache.spark.sql.catalyst.util.sideBySide @@ -2819,6 +2821,21 @@ class DatasetSuite extends QueryTest } } + test("SPARK-49960: joinWith custom encoder") { + /* + test based on "joinWith class with primitive, toDF" + with "custom" encoder. 
Removing the use of AgnosticExpressionPathEncoder + within SerializerBuildHelper and DeserializerBuildHelper will trigger MatchErrors + */ + val ds1 = Seq(1, 1, 2).toDS() + val ds2 = spark.createDataset[ClassData](Seq(ClassData("a", 1), + ClassData("b", 2)))(CustomPathEncoder.custClassDataEnc) + + checkAnswer( + ds1.joinWith(ds2, $"value" === $"b").toDF().select($"_1", $"_2.a", $"_2.b"), + Row(1, "a", 1) :: Row(1, "a", 1) :: Row(2, "b", 2) :: Nil) + } + test("SPARK-49961: transform type should be consistent (classic)") { val ds = Seq(1, 2).toDS() val f: classic.Dataset[Int] => classic.Dataset[Int] = @@ -2858,6 +2875,38 @@ class DatasetSuite extends QueryTest } } +/** + * SPARK-49960 - Mimic a custom encoder such as those provided by typelevel Frameless + */ +object CustomPathEncoder { + + val realClassDataEnc: ProductEncoder[ClassData] = + Encoders.product[ClassData].asInstanceOf[ProductEncoder[ClassData]] + + val custStringEnc: AgnosticExpressionPathEncoder[String] = + new AgnosticExpressionPathEncoder[String] { + + override def toCatalyst(input: Expression): Expression = + createSerializerForString(input) + + override def fromCatalyst(inputPath: Expression): Expression = + createDeserializerForString(inputPath, returnNullable = false) + + override def isPrimitive: Boolean = false + + override def dataType: DataType = StringType + + override def clsTag: ClassTag[String] = implicitly[ClassTag[String]] + + override def isStruct: Boolean = true + } + + val custClassDataEnc: ProductEncoder[ClassData] = realClassDataEnc.copy(fields = + Seq(realClassDataEnc.fields.head.copy(enc = custStringEnc), + realClassDataEnc.fields.last) + ) +} + class DatasetLargeResultCollectingSuite extends QueryTest with SharedSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala index 04f274e4af592..7cc556857774f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala @@ -66,6 +66,9 @@ trait SQLQueryTestHelper extends Logging { s""""location": "$notIncludedMsg/{warehouse_dir}/""") .replaceAll(s""""created_by":".*?"""", s""""created_by $notIncludedMsg":"None"""") .replaceAll(s""""created_time":".*?"""", s""""created_time $notIncludedMsg":"None"""") + .replaceAll(s"transient_lastDdlTime=\\d+", s"transient_lastDdlTime=$notIncludedMsg") + .replaceAll(s""""transient_lastDdlTime":"\\d+"""", + s""""transient_lastDdlTime $notIncludedMsg":"None"""") .replaceAll(s""""last_access":".*?"""", s""""last_access $notIncludedMsg":"None"""") .replaceAll(s""""owner":".*?"""", s""""owner $notIncludedMsg":"None"""") .replaceAll(s""""partition_statistics":"\\d+"""", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DescribeTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DescribeTableSuite.scala index a33782de8809f..edace9b38ce71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DescribeTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DescribeTableSuite.scala @@ -289,7 +289,8 @@ trait DescribeTableSuiteBase extends command.DescribeTableSuiteBase "max_records" -> "1000" )), partition_provider = Some("Catalog"), - partition_columns = Some(List("department", "hire_date")) + partition_columns = Some(List("department", "hire_date")), + clustering_columns = None // no cluster spec for "CLUSTERED 
BY" ) assert(parsedOutput.location.isDefined) @@ -364,6 +365,65 @@ trait DescribeTableSuiteBase extends command.DescribeTableSuiteBase } } + test("DESCRIBE AS JSON cluster spec") { + withNamespaceAndTable("ns", "table") { t => + val tableCreationStr = + s""" + |CREATE TABLE $t ( + | id INT, + | name STRING, + | region STRING, + | category STRING + |) USING parquet + |COMMENT 'test cluster spec' + |CLUSTER BY (id, name) + |TBLPROPERTIES ('t' = 'test') + |""".stripMargin + spark.sql(tableCreationStr) + + val descriptionDf = + spark.sql(s"DESCRIBE FORMATTED $t AS JSON") + val firstRow = descriptionDf.select("json_metadata").head() + val jsonValue = firstRow.getString(0) + val parsedOutput = parse(jsonValue).extract[DescribeTableJson] + + val expectedOutput = DescribeTableJson( + table_name = Some("table"), + catalog_name = Some("spark_catalog"), + namespace = Some(List("ns")), + schema_name = Some("ns"), + columns = Some(List( + TableColumn("id", Type("int"), true), + TableColumn("name", Type("string"), true), + TableColumn("region", Type("string"), true), + TableColumn("category", Type("string"), true) + )), + last_access = Some("UNKNOWN"), + created_by = Some(s"Spark $SPARK_VERSION"), + `type` = Some("MANAGED"), + provider = Some("parquet"), + bucket_columns = Some(Nil), + sort_columns = Some(Nil), + comment = Some("test cluster spec"), + table_properties = Some(Map( + "t" -> "test", + "clusteringColumns" -> "[[\"id\"],[\"name\"]]" + )), + serde_library = if (getProvider() == "hive") { + Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe") + } else { + None + }, + clustering_columns = Some(List("id", "name")) + ) + + assert(parsedOutput.location.isDefined) + assert(iso8601Regex.matches(parsedOutput.created_time.get)) + assert(expectedOutput == parsedOutput.copy( + location = None, created_time = None, storage_properties = None)) + } + } + test("DESCRIBE AS JSON default values") { withNamespaceAndTable("ns", "table") { t => val tableCreationStr = @@ -756,6 +816,7 @@ case class DescribeTableJson( partition_provider: Option[String] = None, partition_columns: Option[List[String]] = Some(Nil), partition_values: Option[Map[String, String]] = None, + clustering_columns: Option[List[String]] = None, statistics: Option[Map[String, Any]] = None, view_text: Option[String] = None, view_original_text: Option[String] = None, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala index 5515c4053bc1c..08a9a306eec32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala @@ -77,6 +77,17 @@ class PushVariantIntoScanSuite extends SharedSparkSession { case _ => fail() } + // Validate _metadata works. 
+ sql("select variant_get(v, '$.a', 'int') as a, _metadata from T") + .queryExecution.optimizedPlan match { + case Project(projectList, l: LogicalRelation) => + val output = l.output + val v = output(0) + checkAlias(projectList(0), "a", GetStructField(v, 0)) + assert(projectList(1).dataType.isInstanceOf[StructType]) + case _ => fail() + } + sql("select 1 from T where isnotnull(v)") .queryExecution.optimizedPlan match { case Project(projectList, Filter(condition, l: LogicalRelation)) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 50240c0605e83..475f8b1168302 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -1196,6 +1196,95 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession } } + testWithColumnFamilies( + "RocksDB: test includesPrefix parameter during changelog replay", + TestWithChangelogCheckpointingEnabled) { colFamiliesEnabled => + + // Only test when column families are enabled, as the includesPrefix parameter + // is only relevant in that case + if (colFamiliesEnabled) { + val remoteDir = Utils.createTempDir().toString + val conf = dbConf.copy(minDeltasForSnapshot = 3, compactOnCommit = false) + new File(remoteDir).delete() // to make sure that the directory gets created + + withDB(remoteDir, conf = conf, useColumnFamilies = true) { db => + // Create a test column family + val testCfName = "test_cf" + db.createColFamilyIfAbsent(testCfName, isInternal = false) + + // Write initial data + db.load(0) + db.put("key1", "value1", StateStore.DEFAULT_COL_FAMILY_NAME) + db.put("key2", "value2", testCfName) + db.commit() + + // Get the encoded keys with column family prefixes + val keyWithPrefix1 = getKeyWithPrefix(db, "key1", StateStore.DEFAULT_COL_FAMILY_NAME) + val keyWithPrefix2 = getKeyWithPrefix(db, "key2", testCfName) + + // Pretend we're replaying changelog with already-prefixed keys + // Throughout this test, we will load version 0 and the latest version + // in order to ensure that the changelog files are read from and + // replayed + db.load(0) + db.load(1) + + // Use the includesPrefix=true parameter with keys that already have prefixes + db.put(keyWithPrefix1, "updated1", includesPrefix = true) + db.put(keyWithPrefix2, "updated2", includesPrefix = true) + db.commit() + + // Verify the updates were applied correctly + db.load(0) + db.load(2) + assert(toStr(db.get("key1", StateStore.DEFAULT_COL_FAMILY_NAME)) === "updated1") + assert(toStr(db.get("key2", testCfName)) === "updated2") + + // Test remove with includesPrefix + db.remove(keyWithPrefix1, includesPrefix = true) + db.remove(keyWithPrefix2, includesPrefix = true) + db.commit() + + // Verify removals worked + db.load(0) + db.load(3) + assert(db.get("key1", StateStore.DEFAULT_COL_FAMILY_NAME) === null) + assert(db.get("key2", testCfName) === null) + + // Add back some data for testing merge operation + db.put("merge_key", "base", StateStore.DEFAULT_COL_FAMILY_NAME) + db.commit() + + // Get encoded key for merge test + val mergeKeyWithPrefix = getKeyWithPrefix( + db, "merge_key", StateStore.DEFAULT_COL_FAMILY_NAME) + + // Test merge with includesPrefix + db.load(0) + db.load(4) + db.merge(mergeKeyWithPrefix, "appended", includesPrefix = true) + db.commit() + + // Verify merge operation worked + 
db.load(0) + db.load(5) + assert(toStr(db.get("merge_key", StateStore.DEFAULT_COL_FAMILY_NAME)) === "base,appended") + } + } + } + + // Helper method to get a key with column family prefix + private def getKeyWithPrefix(db: RocksDB, key: String, cfName: String): Array[Byte] = { + // This uses reflection to call the private encodeStateRowWithPrefix method + val encodeMethod = classOf[RocksDB].getDeclaredMethod( + "encodeStateRowWithPrefix", + classOf[Array[Byte]], + classOf[String] + ) + encodeMethod.setAccessible(true) + encodeMethod.invoke(db, key.getBytes, cfName).asInstanceOf[Array[Byte]] + } + testWithStateStoreCheckpointIdsAndColumnFamilies(s"RocksDB: get, put, iterator, commit, load", TestWithBothChangelogCheckpointingEnabledAndDisabled) { case (enableStateStoreCheckpointIds, colFamiliesEnabled) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index de64ec8ee9ef4..7a1b6b1c5e463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -1846,6 +1846,33 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(sqlScriptText, expected) } + test("for statement - mixed case variable names") { + val sqlScript = + """ + |BEGIN + | DECLARE sum INT = 0; + | FOR LoopCursor AS (SELECT * FROM VALUES (1), (2), (3) AS tbl(RowValue)) DO + | SET sum = sum + LoopCursor.RowValue; + | END FOR; + | SELECT sum; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare sum + Seq.empty[Row], // declare RowValue + Seq.empty[Row], // set RowValue + Seq.empty[Row], // set sum + Seq.empty[Row], // declare RowValue + Seq.empty[Row], // set RowValue + Seq.empty[Row], // set sum + Seq.empty[Row], // declare RowValue + Seq.empty[Row], // set RowValue + Seq.empty[Row], // set sum + Seq(Row(6)) // select + ) + verifySqlScriptResult(sqlScript, expected) + } + test("for statement - enters body once") { withTable("t") { val sqlScript = diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala index cb85993e5e099..da44aa642dae5 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala @@ -34,6 +34,7 @@ class HiveMetastoreLazyInitializationSuite extends SparkFunSuite { .master("local[2]") .enableHiveSupport() .config("spark.hadoop.hive.metastore.uris", "thrift://127.0.0.1:11111") + .config("spark.hadoop.hive.thrift.client.max.message.size", "1gb") .getOrCreate() val originalLevel = LogManager.getRootLogger.asInstanceOf[Logger].getLevel val originalClassLoader = Thread.currentThread().getContextClassLoader diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ca8b5369a9cbb..72fea06936a10 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -22,6 +22,7 @@ import 
java.lang.reflect.InvocationTargetException import java.util import java.util.Locale +import scala.annotation.tailrec import scala.collection.mutable import scala.util.control.NonFatal @@ -81,14 +82,18 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * Due to classloader isolation issues, pattern matching won't work here so we need * to compare the canonical names of the exceptions, which we assume to be stable. */ - private def isClientException(e: Throwable): Boolean = { - var temp: Class[_] = e.getClass - var found = false - while (temp != null && !found) { - found = clientExceptions.contains(temp.getCanonicalName) - temp = temp.getSuperclass - } - found + @tailrec + private def isClientException(e: Throwable): Boolean = e match { + case re: RuntimeException if re.getCause != null => + isClientException(re.getCause) + case e => + var temp: Class[_] = e.getClass + var found = false + while (temp != null && !found) { + found = clientExceptions.contains(temp.getCanonicalName) + temp = temp.getSuperclass + } + found } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 3e7e81d25d943..90f8a3a85d70c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.hive.client import java.io.{OutputStream, PrintStream} import java.lang.{Iterable => JIterable} -import java.lang.reflect.InvocationTargetException +import java.lang.reflect.{InvocationTargetException, Proxy => JdkProxy} import java.nio.charset.StandardCharsets.UTF_8 import java.util.{HashMap => JHashMap, Locale, Map => JMap} import java.util.concurrent.TimeUnit._ +import scala.annotation.tailrec import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ @@ -33,7 +34,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.{IMetaStoreClient, TableType => HiveTableType} +import org.apache.hadoop.hive.metastore.{HiveMetaStoreClient, IMetaStoreClient, RetryingMetaStoreClient, TableType => HiveTableType} import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, Table => MetaStoreApiTable, _} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException, Partition => HivePartition, Table => HiveTable} @@ -43,7 +44,9 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.MetadataTypedColumnsetSerDe import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.hive.thrift.TFilterTransport import org.apache.hadoop.security.UserGroupInformation +import org.apache.thrift.transport.{TEndpointTransport, TTransport} import org.apache.spark.{SparkConf, SparkException, SparkThrowable} import org.apache.spark.deploy.SparkHadoopUtil.SOURCE_SPARK @@ -1407,13 +1410,83 @@ private[hive] object HiveClientImpl extends Logging { case _ => new HiveConf(conf, classOf[HiveConf]) } - try { + val hive = try { Hive.getWithoutRegisterFns(hiveConf) } catch { // SPARK-37069: not all Hive versions have the above method (e.g., Hive 2.3.9 has it but - // 2.3.8 don't), 
therefore here we fallback when encountering the exception. + // 2.3.8 doesn't), therefore here we fall back when encountering the exception. case _: NoSuchMethodError => Hive.get(hiveConf) } + + // Follow the behavior of HIVE-26633 (4.0.0): only apply the max message size when + // `hive.thrift.client.max.message.size` is set and the value is positive + Option(hiveConf.get("hive.thrift.client.max.message.size")) + .map(HiveConf.toSizeBytes(_).toInt).filter(_ > 0) + .foreach { maxMessageSize => + logDebug(s"Trying to set metastore client thrift max message size to $maxMessageSize") + configureMaxThriftMessageSize(hiveConf, hive.getMSC, maxMessageSize) + } + + hive + } + + private def getFieldValue[T](obj: Any, fieldName: String): T = { + val field = obj.getClass.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(obj).asInstanceOf[T] + } + + private def getFieldValue[T](obj: Any, clazz: Class[_], fieldName: String): T = { + val field = clazz.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(obj).asInstanceOf[T] + } + + // SPARK-49489: a workaround for Hive 2.3.10 due to the lack of HIVE-26633 + private def configureMaxThriftMessageSize( + hiveConf: HiveConf, msClient: IMetaStoreClient, maxMessageSize: Int): Unit = try { + msClient match { + // Hive uses Java Dynamic Proxy to enhance the MetaStoreClient to support synchronization + // and retrying; we should unwrap and access the underlying MetaStoreClient instance first + case proxy if JdkProxy.isProxyClass(proxy.getClass) => + JdkProxy.getInvocationHandler(proxy) match { + case syncHandler if syncHandler.getClass.getName.endsWith("SynchronizedHandler") => + val wrappedMsc = getFieldValue[IMetaStoreClient](syncHandler, "client") + configureMaxThriftMessageSize(hiveConf, wrappedMsc, maxMessageSize) + case retryHandler: RetryingMetaStoreClient => + val wrappedMsc = getFieldValue[IMetaStoreClient](retryHandler, "base") + configureMaxThriftMessageSize(hiveConf, wrappedMsc, maxMessageSize) + case _ => + } + case msc: HiveMetaStoreClient if !msc.isLocalMetaStore => + @tailrec + def configure(t: TTransport): Unit = t match { + // Unwrap and access the underlying TTransport when security is enabled (Kerberos) + case tTransport: TFilterTransport => + val wrappedTTransport = getFieldValue[TTransport]( + tTransport, classOf[TFilterTransport], "wrapped") + configure(wrappedTTransport) + case tTransport: TEndpointTransport => + val tConf = tTransport.getConfiguration + val currentMaxMessageSize = tConf.getMaxMessageSize + if (currentMaxMessageSize != maxMessageSize) { + logDebug("Change the current metastore client thrift max message size from " + + s"$currentMaxMessageSize to $maxMessageSize") + tConf.setMaxMessageSize(maxMessageSize) + // This internally calls TEndpointTransport#resetConsumedMessageSize(-1L) to + // apply the updated maxMessageSize + tTransport.updateKnownMessageSize(0L) + } + case _ => + } + configure(msc.getTTransport) + case _ => // do nothing + } + } catch { + // TEndpointTransport was added in THRIFT-5237 (0.14.0); Hive versions that use an older + // Thrift library (e.g. Hive 2.3.9 uses Thrift 0.9.3) aren't affected by THRIFT-5237 + // and don't need to apply HIVE-26633 + case _: NoClassDefFoundError => // do nothing } }
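
The SQLQueryTestHelper change above masks transient_lastDdlTime, a value that differs on every run, so the DESC FORMATTED and DESC ... AS JSON golden files stay stable across executions. A minimal standalone sketch of the same regex-masking idea, with hypothetical object and method names (not Spark's helper):

object OutputNormalizationSketch {
  private val notIncludedMsg = "[not included in comparison]"

  // Mask values that change between runs; mirrors the style of the replacements
  // added to SQLQueryTestHelper, but this object is only an illustrative sketch.
  def normalize(output: String): String = output
    .replaceAll("transient_lastDdlTime=\\d+", s"transient_lastDdlTime=$notIncludedMsg")
    .replaceAll(""""transient_lastDdlTime":"\d+"""",
      s""""transient_lastDdlTime $notIncludedMsg":"None"""")

  def main(args: Array[String]): Unit = {
    val raw = "Partition Parameters [numFiles=1, transient_lastDdlTime=1731847993]"
    println(normalize(raw))
    // Partition Parameters [numFiles=1, transient_lastDdlTime=[not included in comparison]]
  }
}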
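
The RocksDBSuite test above exercises the new includesPrefix parameter: keys read back during changelog replay already carry the column-family prefix, so put/remove/merge must not encode them a second time. The sketch below illustrates only that encoding convention, using a toy in-memory store and made-up column-family ids; it is not Spark's RocksDB wrapper.

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets.UTF_8

object PrefixedKeySketch {
  // Toy in-memory "store": keys from all column families share one key space,
  // distinguished by a fixed-width column-family id prefix.
  private val store = scala.collection.mutable.Map.empty[Seq[Byte], Array[Byte]]
  private val cfIds = Map("default" -> 1.toShort, "test_cf" -> 2.toShort)

  private def encode(key: Array[Byte], cfName: String): Array[Byte] = {
    val buf = ByteBuffer.allocate(2 + key.length)
    buf.putShort(cfIds(cfName))
    buf.put(key)
    buf.array()
  }

  // Callers that already hold an encoded key (e.g. changelog replay) pass
  // includesPrefix = true so the key is not prefixed twice.
  def put(key: Array[Byte], value: Array[Byte], cfName: String = "default",
      includesPrefix: Boolean = false): Unit = {
    val fullKey = if (includesPrefix) key else encode(key, cfName)
    store(fullKey.toSeq) = value
  }

  def get(key: Array[Byte], cfName: String = "default"): Option[Array[Byte]] =
    store.get(encode(key, cfName).toSeq)

  def main(args: Array[String]): Unit = {
    put("key1".getBytes(UTF_8), "value1".getBytes(UTF_8), "test_cf")
    // Replay path: the key recovered from the changelog is already prefixed.
    val replayedKey = encode("key1".getBytes(UTF_8), "test_cf")
    put(replayedKey, "updated1".getBytes(UTF_8), includesPrefix = true)
    assert(get("key1".getBytes(UTF_8), "test_cf").map(new String(_, UTF_8)).contains("updated1"))
  }
}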
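
The HiveExternalCatalog change rewrites isClientException to unwrap RuntimeException causes tail-recursively before walking the class hierarchy and comparing canonical names (direct pattern matching on the exception classes fails across isolated classloaders). A self-contained sketch of that pattern, with a placeholder exception-name set standing in for the real clientExceptions:

import scala.annotation.tailrec

object CauseUnwrapSketch {
  // Placeholder for the real clientExceptions set of canonical class names.
  private val clientExceptionNames = Set(classOf[IllegalArgumentException].getCanonicalName)

  // Unwrap RuntimeException causes first, then climb the class hierarchy of the
  // innermost exception and compare canonical names.
  @tailrec
  def isClientException(e: Throwable): Boolean = e match {
    case re: RuntimeException if re.getCause != null =>
      isClientException(re.getCause)
    case other =>
      var temp: Class[_] = other.getClass
      var found = false
      while (temp != null && !found) {
        found = clientExceptionNames.contains(temp.getCanonicalName)
        temp = temp.getSuperclass
      }
      found
  }

  def main(args: Array[String]): Unit = {
    // A client-side failure wrapped in a RuntimeException is still recognized.
    assert(isClientException(new RuntimeException(new IllegalArgumentException("boom"))))
    assert(!isClientException(new IllegalStateException("not a client error")))
  }
}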
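
The HiveClientImpl change has to reach the Thrift transport buried beneath the dynamic proxies that Hive layers around IMetaStoreClient, unwrapping each layer and reading private fields by reflection where no accessor exists. The sketch below shows only that generic JDK-proxy-unwrapping technique with hypothetical Service and RetryingHandler types; it is not Hive's API.

import java.lang.reflect.{InvocationHandler, Method, Proxy => JdkProxy}

object ProxyUnwrapSketch {
  // Hypothetical service interface, standing in for IMetaStoreClient.
  trait Service { def ping(): String }

  // Handler that delegates to a concrete instance, standing in for the retrying
  // and synchronizing wrappers that Hive layers around the client.
  class RetryingHandler(val base: Service) extends InvocationHandler {
    override def invoke(proxy: Any, method: Method, args: Array[AnyRef]): AnyRef = {
      val actualArgs = if (args == null) Array.empty[AnyRef] else args
      method.invoke(base, actualArgs: _*)
    }
  }

  // Read a private field by reflection, mirroring the getFieldValue helper in the patch.
  private def getFieldValue[T](obj: Any, fieldName: String): T = {
    val field = obj.getClass.getDeclaredField(fieldName)
    field.setAccessible(true)
    field.get(obj).asInstanceOf[T]
  }

  // If the instance is a JDK dynamic proxy, pull the delegate out of its handler
  // and recurse; otherwise it is already the concrete implementation.
  def unwrap(service: Service): Service = service match {
    case p if JdkProxy.isProxyClass(p.getClass) =>
      JdkProxy.getInvocationHandler(p) match {
        case h: RetryingHandler => unwrap(h.base)
        case h => unwrap(getFieldValue[Service](h, "base"))
      }
    case concrete => concrete
  }

  def main(args: Array[String]): Unit = {
    val real: Service = () => "pong"
    val proxied = JdkProxy.newProxyInstance(
      getClass.getClassLoader, Array(classOf[Service]), new RetryingHandler(real))
      .asInstanceOf[Service]
    assert(unwrap(proxied) eq real) // the concrete client is reachable again
    assert(proxied.ping() == "pong") // calls still go through the handler
  }
}

As in the patch, reading a private field by reflection is the fallback when a wrapper exposes no public accessor for its delegate.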