diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 24d0a5ac7262f..092a72ade1ea5 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql import java.net.URI +import java.nio.file.{Files, Paths} import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag +import scala.util.Try import com.google.common.cache.{CacheBuilder, CacheLoader} import io.grpc.ClientInterceptor @@ -591,6 +593,10 @@ class SparkSession private[sql] ( object SparkSession extends Logging { private val MAX_CACHED_SESSIONS = 100 private val planIdGenerator = new AtomicLong + private var server: Option[Process] = None + private[sql] val sparkOptions = sys.props.filter { p => + p._1.startsWith("spark.") && p._2.nonEmpty + }.toMap private val sessions = CacheBuilder .newBuilder() @@ -623,6 +629,51 @@ object SparkSession extends Logging { } } + /** + * Create a new Spark Connect server to connect locally. + */ + private[sql] def withLocalConnectServer[T](f: => T): T = { + synchronized { + val remoteString = sparkOptions + .get("spark.remote") + .orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit + .orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE)) + + val maybeConnectScript = + Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh")) + + if (server.isEmpty && + remoteString.exists(_.startsWith("local")) && + maybeConnectScript.exists(Files.exists(_))) { + server = Some { + val args = + Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions + .filter(p => !p._1.startsWith("spark.remote")) + .flatMap { case (k, v) => Seq("--conf", s"$k=$v") } + val pb = new ProcessBuilder(args: _*) + // So don't exclude spark-sql jar in classpath + pb.environment().remove(SparkConnectClient.SPARK_REMOTE) + pb.start() + } + + // Let the server start. We will directly request to set the configurations + // and this sleep makes less noisy with retries. + Thread.sleep(2000L) + System.setProperty("spark.remote", "sc://localhost") + + // scalastyle:off runtimeaddshutdownhook + Runtime.getRuntime.addShutdownHook(new Thread() { + override def run(): Unit = if (server.isDefined) { + new ProcessBuilder(maybeConnectScript.get.toString) + .start() + } + }) + // scalastyle:on runtimeaddshutdownhook + } + } + f + } + /** * Create a new [[SparkSession]] based on the connect client [[Configuration]]. */ @@ -765,6 +816,16 @@ object SparkSession extends Logging { } private def applyOptions(session: SparkSession): Unit = { + // Only attempts to set Spark SQL configurations. + // If the configurations are static, it might throw an exception so + // simply ignore it for now. + sparkOptions + .filter { case (k, _) => + k.startsWith("spark.sql.") + } + .foreach { case (key, value) => + Try(session.conf.set(key, value)) + } options.foreach { case (key, value) => session.conf.set(key, value) } @@ -787,7 +848,7 @@ object SparkSession extends Logging { * * @since 3.5.0 */ - def create(): SparkSession = { + def create(): SparkSession = withLocalConnectServer { val session = tryCreateSessionFromClient() .getOrElse(SparkSession.this.create(builder.configuration)) setDefaultAndActiveSession(session) @@ -807,7 +868,7 @@ object SparkSession extends Logging { * * @since 3.5.0 */ - def getOrCreate(): SparkSession = { + def getOrCreate(): SparkSession = withLocalConnectServer { val session = tryCreateSessionFromClient() .getOrElse({ var existingSession = sessions.get(builder.configuration) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala index 86775803a0937..63fa2821a6c6a 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.application import java.io.{InputStream, OutputStream} -import java.nio.file.Paths import java.util.concurrent.Semaphore -import scala.util.Try import scala.util.control.NonFatal import ammonite.compiler.CodeClassWrapper @@ -34,6 +32,7 @@ import ammonite.util.Util.newLine import org.apache.spark.SparkBuildInfo.spark_version import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.SparkSession.withLocalConnectServer import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkConnectClientParser} /** @@ -64,37 +63,7 @@ Spark session available as 'spark'. semaphore: Option[Semaphore] = None, inputStream: InputStream = System.in, outputStream: OutputStream = System.out, - errorStream: OutputStream = System.err): Unit = { - val configs: Map[String, String] = - sys.props - .filter(p => - p._1.startsWith("spark.") && - p._2.nonEmpty && - // Don't include spark.remote that we manually set later. - !p._1.startsWith("spark.remote")) - .toMap - - val remoteString: Option[String] = - Option(System.getProperty("spark.remote")) // Set from Spark Submit - .orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE)) - - if (remoteString.exists(_.startsWith("local"))) { - server = Some { - val args = Seq( - Paths.get(sparkHome, "sbin", "start-connect-server.sh").toString, - "--master", - remoteString.get) ++ configs.flatMap { case (k, v) => Seq("--conf", s"$k=$v") } - val pb = new ProcessBuilder(args: _*) - // So don't exclude spark-sql jar in classpath - pb.environment().remove(SparkConnectClient.SPARK_REMOTE) - pb.start() - } - // Let the server start. We will directly request to set the configurations - // and this sleep makes less noisy with retries. - Thread.sleep(2000L) - System.setProperty("spark.remote", "sc://localhost") - } - + errorStream: OutputStream = System.err): Unit = withLocalConnectServer { // Build the client. val client = try { @@ -118,13 +87,6 @@ Spark session available as 'spark'. // Build the session. val spark = SparkSession.builder().client(client).getOrCreate() - - // The configurations might not be all runtime configurations. - // Try to set them with ignoring failures for now. - configs - .filter(_._1.startsWith("spark.sql")) - .foreach { case (k, v) => Try(spark.conf.set(k, v)) } - val sparkBind = new Bind("spark", spark) // Add the proper imports and register a [[ClassFinder]]. @@ -197,18 +159,12 @@ Spark session available as 'spark'. } } } - try { - if (semaphore.nonEmpty) { - // Used for testing. - main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get)) - } else { - main.run(sparkBind) - } - } finally { - if (server.isDefined) { - new ProcessBuilder(Paths.get(sparkHome, "sbin", "stop-connect-server.sh").toString) - .start() - } + + if (semaphore.nonEmpty) { + // Used for testing. + main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get)) + } else { + main.run(sparkBind) } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 667bf8bbc9754..e9507fa6bee48 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -25,6 +25,9 @@ import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.Try +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.{SparkConf, SparkUserAppException} import org.apache.spark.api.python.{Py4JServer, PythonUtils} import org.apache.spark.internal.config._ @@ -50,18 +53,21 @@ object PythonRunner { val formattedPythonFile = formatPath(pythonFile) val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles)) - val gatewayServer = new Py4JServer(sparkConf) + var gatewayServer: Option[Py4JServer] = None + if (sparkConf.getOption("spark.remote").isEmpty) { + gatewayServer = Some(new Py4JServer(sparkConf)) - val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.start() }) - thread.setName("py4j-gateway-init") - thread.setDaemon(true) - thread.start() + val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.get.start() }) + thread.setName("py4j-gateway-init") + thread.setDaemon(true) + thread.start() - // Wait until the gateway server has started, so that we know which port is it bound to. - // `gatewayServer.start()` will start a new thread and run the server code there, after - // initializing the socket, so the thread started above will end as soon as the server is - // ready to serve connections. - thread.join() + // Wait until the gateway server has started, so that we know which port is it bound to. + // `gatewayServer.start()` will start a new thread and run the server code there, after + // initializing the socket, so the thread started above will end as soon as the server is + // ready to serve connections. + thread.join() + } // Build up a PYTHONPATH that includes the Spark assembly (where this class is), the // python directories in SPARK_HOME (if set), and any files in the pyFiles argument @@ -74,12 +80,22 @@ object PythonRunner { // Launch Python process val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava) val env = builder.environment() + if (sparkConf.getOption("spark.remote").nonEmpty) { + // For non-local remote, pass configurations to environment variables so + // Spark Connect client sets them. For local remotes, they will be set + // via Py4J. + val grouped = sparkConf.getAll.toMap.grouped(10).toSeq + env.put("PYSPARK_REMOTE_INIT_CONF_LEN", grouped.length.toString) + grouped.zipWithIndex.foreach { case (group, idx) => + env.put(s"PYSPARK_REMOTE_INIT_CONF_$idx", compact(render(group))) + } + } sparkConf.getOption("spark.remote").foreach(url => env.put("SPARK_REMOTE", url)) env.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string - env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) - env.put("PYSPARK_GATEWAY_SECRET", gatewayServer.secret) + gatewayServer.foreach(s => env.put("PYSPARK_GATEWAY_PORT", s.getListeningPort.toString)) + gatewayServer.foreach(s => env.put("PYSPARK_GATEWAY_SECRET", s.secret)) // pass conf spark.pyspark.python to python process, the only way to pass info to // python process is through environment variable. sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _)) @@ -103,7 +119,7 @@ object PythonRunner { throw new SparkUserAppException(exitCode) } } finally { - gatewayServer.shutdown() + gatewayServer.foreach(_.shutdown()) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 40d8eae644a07..ca81283e073ac 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -1802,6 +1802,23 @@ class SparkSubmitSuite val (_, classpath, _, _) = submit.prepareSubmitEnvironment(appArgs) assert(classpath.contains(".")) } + + // Requires Python dependencies for Spark Connect. Should be enabled by default. + ignore("Spark Connect application submission (Python)") { + val pyFile = File.createTempFile("remote_test", ".py") + pyFile.deleteOnExit() + val content = + "from pyspark.sql import SparkSession;" + + "spark = SparkSession.builder.getOrCreate();" + + "assert 'connect' in str(type(spark));" + + "assert spark.range(1).first()[0] == 0" + FileUtils.write(pyFile, content, StandardCharsets.UTF_8) + val args = Seq( + "--name", "testPyApp", + "--remote", "local", + pyFile.getAbsolutePath) + runSparkSubmit(args) + } } object JarCreationTest extends Logging { diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 5826237f0c33a..a9d0e41aa3cbb 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -214,7 +214,9 @@ List buildClassPath(String appClassPath) throws IOException { addToClassPath(cp, f.toString()); } } - if (isRemote && "1".equals(getenv("SPARK_SCALA_SHELL"))) { + // If we're in 'spark.local.connect', it should create a Spark Classic Spark Context + // that launches Spark Connect server. + if (isRemote && System.getenv("SPARK_LOCAL_CONNECT") == null) { for (File f: new File(jarsDir).listFiles()) { // Exclude Spark Classic SQL and Spark Connect server jars // if we're in Spark Connect Shell. Also exclude Spark SQL API and diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index eebd04fe4c5b1..8d95bc06d7a7d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -82,10 +82,6 @@ public List buildCommand(Map env) javaOptsKeys.add("SPARK_BEELINE_OPTS"); yield "SPARK_BEELINE_MEMORY"; } - case "org.apache.spark.sql.application.ConnectRepl" -> { - isRemote = true; - yield "SPARK_DRIVER_MEMORY"; - } default -> "SPARK_DRIVER_MEMORY"; }; diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index e5246e893f658..cacb479229bb7 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -19,6 +19,7 @@ check_dependencies(__name__) +import json import threading import os import warnings @@ -200,6 +201,26 @@ def enableHiveSupport(self) -> "SparkSession.Builder": ) def _apply_options(self, session: "SparkSession") -> None: + init_opts = {} + for i in range(int(os.environ.get("PYSPARK_REMOTE_INIT_CONF_LEN", "0"))): + init_opts = json.loads(os.environ[f"PYSPARK_REMOTE_INIT_CONF_{i}"]) + + with self._lock: + for k, v in init_opts.items(): + # the options are applied after session creation, + # so following options always take no effect + if k not in [ + "spark.remote", + "spark.master", + ] and k.startswith("spark.sql."): + # Only attempts to set Spark SQL configurations. + # If the configurations are static, it might throw an exception so + # simply ignore it for now. + try: + session.conf.set(k, v) + except Exception: + pass + with self._lock: for k, v in self._options.items(): # the options are applied after session creation, @@ -993,10 +1014,17 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None: session = PySparkSession._instantiatedSession if session is None or session._sc._jsc is None: + init_opts = {} + for i in range(int(os.environ.get("PYSPARK_REMOTE_INIT_CONF_LEN", "0"))): + init_opts = json.loads(os.environ[f"PYSPARK_REMOTE_INIT_CONF_{i}"]) + init_opts.update(opts) + opts = init_opts + # Configurations to be overwritten overwrite_conf = opts overwrite_conf["spark.master"] = master overwrite_conf["spark.local.connect"] = "1" + os.environ["SPARK_LOCAL_CONNECT"] = "1" # Configurations to be set if unset. default_conf = {"spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin"} @@ -1030,6 +1058,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None: finally: if origin_remote is not None: os.environ["SPARK_REMOTE"] = origin_remote + del os.environ["SPARK_LOCAL_CONNECT"] else: raise PySparkRuntimeError( errorClass="SESSION_OR_CONTEXT_EXISTS",