[SPARK-48960][CONNECT] Make spark-submit work with Spark Connect
### What changes were proposed in this pull request?

This PR proposes to add support for `--remote` to `bin/spark-submit` so that Spark Connect can be used easily. This PR includes:
- Making `bin/spark-submit` work with the Scala Spark Connect client
- Passing `--conf` and loaded configurations to both the Scala and Python Spark Connect clients

### Why are the changes needed?

`bin/pyspark --remote` already works. We should also make `bin/spark-submit` work so that end users can try Spark Connect out, and so that the entry points behave consistently.

### Does this PR introduce _any_ user-facing change?

Yes:
- `bin/spark-submit` supports the `--remote` option in Scala.
- `bin/spark-submit` passes `--conf` and loaded Spark configurations to the Scala and Python clients.

### How was this patch tested?

Python:

```bash
echo "from pyspark.sql import SparkSession;spark = SparkSession.builder.getOrCreate();assert 'connect' in str(type(spark));assert spark.range(1).first()[0] == 0" > test.py
```

```bash
./bin/spark-submit --name "testApp" --remote "local" test.py
```

Scala:

https://github.com/HyukjinKwon/spark-connect-example

```bash
git clone https://github.com/HyukjinKwon/spark-connect-example
cd spark-connect-example
build/sbt package
cd ..
git clone https://github.com/apache/spark.git
cd spark
build/sbt package
sbin/start-connect-server.sh
bin/spark-submit --name "testApp" --remote "sc://localhost" --class com.hyukjinkwon.SparkConnectExample ../spark-connect-example/target/scala-2.13/spark-connect-example_2.13-0.0.1.jar
```
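
The contents of the example repository are not reproduced here. As a rough sketch, an application submitted this way could look like the following (the package and object name match the `--class` above; the body is illustrative, not the repository's actual code):

```scala
package com.hyukjinkwon

import org.apache.spark.sql.SparkSession

object SparkConnectExample {
  def main(args: Array[String]): Unit = {
    // With --remote, spark-submit wires this builder to the Spark Connect
    // Scala client, so the session talks to the server at sc://localhost.
    val spark = SparkSession.builder().getOrCreate()
    assert(spark.range(1).count() == 1)
    spark.close()
  }
}
```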

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47434 from HyukjinKwon/SPARK-48960.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
HyukjinKwon committed Sep 4, 2024
1 parent 90a236e commit 3e22c47
Showing 7 changed files with 149 additions and 72 deletions.
@@ -17,11 +17,13 @@
package org.apache.spark.sql

import java.net.URI
import java.nio.file.{Files, Paths}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}

import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try

import com.google.common.cache.{CacheBuilder, CacheLoader}
import io.grpc.ClientInterceptor
@@ -591,6 +593,10 @@ class SparkSession private[sql] (
object SparkSession extends Logging {
private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
private var server: Option[Process] = None
private[sql] val sparkOptions = sys.props.filter { p =>
p._1.startsWith("spark.") && p._2.nonEmpty
}.toMap

private val sessions = CacheBuilder
.newBuilder()
@@ -623,6 +629,51 @@ object SparkSession extends Logging {
}
}

/**
* Create a new Spark Connect server to connect locally.
*/
private[sql] def withLocalConnectServer[T](f: => T): T = {
synchronized {
val remoteString = sparkOptions
.get("spark.remote")
.orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit
.orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))

val maybeConnectScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))

if (server.isEmpty &&
remoteString.exists(_.startsWith("local")) &&
maybeConnectScript.exists(Files.exists(_))) {
server = Some {
val args =
Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
.filter(p => !p._1.startsWith("spark.remote"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// Unset SPARK_REMOTE so the launched server does not exclude the spark-sql jar from its classpath.
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.start()
}

// Let the server start. We will directly request to set the configurations,
// and this sleep makes the retries less noisy.
Thread.sleep(2000L)
System.setProperty("spark.remote", "sc://localhost")

// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectScript.get.toString)
.start()
}
})
// scalastyle:on runtimeaddshutdownhook
}
}
f
}

/**
* Create a new [[SparkSession]] based on the connect client [[Configuration]].
*/
@@ -765,6 +816,16 @@
}

private def applyOptions(session: SparkSession): Unit = {
// Only attempt to set Spark SQL configurations.
// If a configuration is static, setting it might throw an exception,
// so simply ignore failures for now.
sparkOptions
.filter { case (k, _) =>
k.startsWith("spark.sql.")
}
.foreach { case (key, value) =>
Try(session.conf.set(key, value))
}
options.foreach { case (key, value) =>
session.conf.set(key, value)
}
@@ -787,7 +848,7 @@
*
* @since 3.5.0
*/
def create(): SparkSession = {
def create(): SparkSession = withLocalConnectServer {
val session = tryCreateSessionFromClient()
.getOrElse(SparkSession.this.create(builder.configuration))
setDefaultAndActiveSession(session)
@@ -807,7 +868,7 @@
*
* @since 3.5.0
*/
def getOrCreate(): SparkSession = {
def getOrCreate(): SparkSession = withLocalConnectServer {
val session = tryCreateSessionFromClient()
.getOrElse({
var existingSession = sessions.get(builder.configuration)
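
In the change above, all `spark.*` JVM system properties set by `spark-submit` (except `spark.remote`) are forwarded to the local Connect server as `--conf` arguments, and only the `spark.sql.*` ones are later re-applied to the session, ignoring failures from static configurations. A minimal, self-contained sketch of that filtering with made-up configuration values (not the actual Spark code):

```scala
object ConfForwardingSketch {
  def main(args: Array[String]): Unit = {
    // Stand-in for SparkSession.sparkOptions (spark.* JVM system properties).
    val sparkOptions = Map(
      "spark.remote" -> "local",
      "spark.app.name" -> "testApp",
      "spark.sql.ansi.enabled" -> "true")

    // Forwarded to sbin/start-connect-server.sh as --conf key=value pairs.
    val serverArgs = sparkOptions
      .filter { case (k, _) => !k.startsWith("spark.remote") }
      .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }

    // Re-applied to the Connect session afterwards (spark.sql.* only).
    val sessionConfs = sparkOptions.filter { case (k, _) => k.startsWith("spark.sql.") }

    // e.g. --conf spark.app.name=testApp --conf spark.sql.ansi.enabled=true
    println(serverArgs.mkString(" "))
    println(sessionConfs) // Map(spark.sql.ansi.enabled -> true)
  }
}
```
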
@@ -17,10 +17,8 @@
package org.apache.spark.sql.application

import java.io.{InputStream, OutputStream}
import java.nio.file.Paths
import java.util.concurrent.Semaphore

import scala.util.Try
import scala.util.control.NonFatal

import ammonite.compiler.CodeClassWrapper
@@ -34,6 +32,7 @@ import ammonite.util.Util.newLine
import org.apache.spark.SparkBuildInfo.spark_version
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SparkSession.withLocalConnectServer
import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkConnectClientParser}

/**
@@ -64,37 +63,7 @@ Spark session available as 'spark'.
semaphore: Option[Semaphore] = None,
inputStream: InputStream = System.in,
outputStream: OutputStream = System.out,
errorStream: OutputStream = System.err): Unit = {
val configs: Map[String, String] =
sys.props
.filter(p =>
p._1.startsWith("spark.") &&
p._2.nonEmpty &&
// Don't include spark.remote that we manually set later.
!p._1.startsWith("spark.remote"))
.toMap

val remoteString: Option[String] =
Option(System.getProperty("spark.remote")) // Set from Spark Submit
.orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))

if (remoteString.exists(_.startsWith("local"))) {
server = Some {
val args = Seq(
Paths.get(sparkHome, "sbin", "start-connect-server.sh").toString,
"--master",
remoteString.get) ++ configs.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// Unset SPARK_REMOTE so the launched server does not exclude the spark-sql jar from its classpath.
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.start()
}
// Let the server start. We will directly request to set the configurations,
// and this sleep makes the retries less noisy.
Thread.sleep(2000L)
System.setProperty("spark.remote", "sc://localhost")
}

errorStream: OutputStream = System.err): Unit = withLocalConnectServer {
// Build the client.
val client =
try {
@@ -118,13 +87,6 @@

// Build the session.
val spark = SparkSession.builder().client(client).getOrCreate()

// The configurations might not be all runtime configurations.
// Try to set them with ignoring failures for now.
configs
.filter(_._1.startsWith("spark.sql"))
.foreach { case (k, v) => Try(spark.conf.set(k, v)) }

val sparkBind = new Bind("spark", spark)

// Add the proper imports and register a [[ClassFinder]].
@@ -197,18 +159,12 @@ Spark session available as 'spark'.
}
}
}
try {
if (semaphore.nonEmpty) {
// Used for testing.
main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get))
} else {
main.run(sparkBind)
}
} finally {
if (server.isDefined) {
new ProcessBuilder(Paths.get(sparkHome, "sbin", "stop-connect-server.sh").toString)
.start()
}

if (semaphore.nonEmpty) {
// Used for testing.
main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get))
} else {
main.run(sparkBind)
}
}
}
42 changes: 29 additions & 13 deletions core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -25,6 +25,9 @@ import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.util.Try

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.{SparkConf, SparkUserAppException}
import org.apache.spark.api.python.{Py4JServer, PythonUtils}
import org.apache.spark.internal.config._
@@ -50,18 +53,21 @@ object PythonRunner {
val formattedPythonFile = formatPath(pythonFile)
val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles))

val gatewayServer = new Py4JServer(sparkConf)
var gatewayServer: Option[Py4JServer] = None
if (sparkConf.getOption("spark.remote").isEmpty) {
gatewayServer = Some(new Py4JServer(sparkConf))

val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.start() })
thread.setName("py4j-gateway-init")
thread.setDaemon(true)
thread.start()
val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.get.start() })
thread.setName("py4j-gateway-init")
thread.setDaemon(true)
thread.start()

// Wait until the gateway server has started, so that we know which port it is bound to.
// `gatewayServer.start()` will start a new thread and run the server code there, after
// initializing the socket, so the thread started above will end as soon as the server is
// ready to serve connections.
thread.join()
// Wait until the gateway server has started, so that we know which port it is bound to.
// `gatewayServer.start()` will start a new thread and run the server code there, after
// initializing the socket, so the thread started above will end as soon as the server is
// ready to serve connections.
thread.join()
}

// Build up a PYTHONPATH that includes the Spark assembly (where this class is), the
// python directories in SPARK_HOME (if set), and any files in the pyFiles argument
@@ -74,12 +80,22 @@
// Launch Python process
val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava)
val env = builder.environment()
if (sparkConf.getOption("spark.remote").nonEmpty) {
// For a non-local remote, pass configurations via environment variables so the
// Spark Connect client can set them. For local remotes, they will be set
// via Py4J.
val grouped = sparkConf.getAll.toMap.grouped(10).toSeq
env.put("PYSPARK_REMOTE_INIT_CONF_LEN", grouped.length.toString)
grouped.zipWithIndex.foreach { case (group, idx) =>
env.put(s"PYSPARK_REMOTE_INIT_CONF_$idx", compact(render(group)))
}
}
sparkConf.getOption("spark.remote").foreach(url => env.put("SPARK_REMOTE", url))
env.put("PYTHONPATH", pythonPath)
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
env.put("PYSPARK_GATEWAY_SECRET", gatewayServer.secret)
gatewayServer.foreach(s => env.put("PYSPARK_GATEWAY_PORT", s.getListeningPort.toString))
gatewayServer.foreach(s => env.put("PYSPARK_GATEWAY_SECRET", s.secret))
// Pass the conf spark.pyspark.python to the python process; the only way to pass
// info to the python process is through environment variables.
sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
@@ -103,7 +119,7 @@ object PythonRunner {
throw new SparkUserAppException(exitCode)
}
} finally {
gatewayServer.shutdown()
gatewayServer.foreach(_.shutdown())
}
}

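The handshake added above works as follows: when `spark.remote` is set, `PythonRunner` splits the Spark configurations into groups of 10, serializes each group as a JSON object into `PYSPARK_REMOTE_INIT_CONF_<i>`, and records the number of groups in `PYSPARK_REMOTE_INIT_CONF_LEN` so the Python-side Connect session can read them back. A standalone sketch of that encoding (illustrative values; assumes json4s on the classpath, as in the code above):

```scala
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

object InitConfEnvSketch {
  def main(args: Array[String]): Unit = {
    // Stand-in for sparkConf.getAll.toMap in PythonRunner.
    val conf = Map(
      "spark.remote" -> "sc://localhost",
      "spark.app.name" -> "testApp",
      "spark.sql.shuffle.partitions" -> "4")

    val env = scala.collection.mutable.Map[String, String]()
    val grouped = conf.grouped(10).toSeq
    env.put("PYSPARK_REMOTE_INIT_CONF_LEN", grouped.length.toString)
    grouped.zipWithIndex.foreach { case (group, idx) =>
      // Each group becomes one JSON object, e.g. {"spark.remote":"sc://localhost",...}
      env.put(s"PYSPARK_REMOTE_INIT_CONF_$idx", compact(render(group)))
    }
    env.foreach { case (k, v) => println(s"$k=$v") }
  }
}
```
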
17 changes: 17 additions & 0 deletions core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -1802,6 +1802,23 @@ class SparkSubmitSuite
val (_, classpath, _, _) = submit.prepareSubmitEnvironment(appArgs)
assert(classpath.contains("."))
}

// Requires Python dependencies for Spark Connect. Should be enabled by default.
ignore("Spark Connect application submission (Python)") {
val pyFile = File.createTempFile("remote_test", ".py")
pyFile.deleteOnExit()
val content =
"from pyspark.sql import SparkSession;" +
"spark = SparkSession.builder.getOrCreate();" +
"assert 'connect' in str(type(spark));" +
"assert spark.range(1).first()[0] == 0"
FileUtils.write(pyFile, content, StandardCharsets.UTF_8)
val args = Seq(
"--name", "testPyApp",
"--remote", "local",
pyFile.getAbsolutePath)
runSparkSubmit(args)
}
}

object JarCreationTest extends Logging {
@@ -214,7 +214,9 @@ List<String> buildClassPath(String appClassPath) throws IOException {
addToClassPath(cp, f.toString());
}
}
if (isRemote && "1".equals(getenv("SPARK_SCALA_SHELL"))) {
// If we're in 'spark.local.connect', it should create a Spark Classic SparkContext
// that launches the Spark Connect server.
if (isRemote && System.getenv("SPARK_LOCAL_CONNECT") == null) {
for (File f: new File(jarsDir).listFiles()) {
// Exclude Spark Classic SQL and Spark Connect server jars
// if we're in Spark Connect Shell. Also exclude Spark SQL API and
@@ -82,10 +82,6 @@ public List<String> buildCommand(Map<String, String> env)
javaOptsKeys.add("SPARK_BEELINE_OPTS");
yield "SPARK_BEELINE_MEMORY";
}
case "org.apache.spark.sql.application.ConnectRepl" -> {
isRemote = true;
yield "SPARK_DRIVER_MEMORY";
}
default -> "SPARK_DRIVER_MEMORY";
};

29 changes: 29 additions & 0 deletions python/pyspark/sql/connect/session.py
@@ -19,6 +19,7 @@

check_dependencies(__name__)

import json
import threading
import os
import warnings
@@ -200,6 +201,26 @@ def enableHiveSupport(self) -> "SparkSession.Builder":
)

def _apply_options(self, session: "SparkSession") -> None:
init_opts = {}
for i in range(int(os.environ.get("PYSPARK_REMOTE_INIT_CONF_LEN", "0"))):
init_opts = json.loads(os.environ[f"PYSPARK_REMOTE_INIT_CONF_{i}"])

with self._lock:
for k, v in init_opts.items():
# the options are applied after session creation,
# so the following options never take effect
if k not in [
"spark.remote",
"spark.master",
] and k.startswith("spark.sql."):
# Only attempt to set Spark SQL configurations.
# If a configuration is static, setting it might throw an exception,
# so simply ignore failures for now.
try:
session.conf.set(k, v)
except Exception:
pass

with self._lock:
for k, v in self._options.items():
# the options are applied after session creation,
@@ -993,10 +1014,17 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:

session = PySparkSession._instantiatedSession
if session is None or session._sc._jsc is None:
init_opts = {}
for i in range(int(os.environ.get("PYSPARK_REMOTE_INIT_CONF_LEN", "0"))):
init_opts = json.loads(os.environ[f"PYSPARK_REMOTE_INIT_CONF_{i}"])
init_opts.update(opts)
opts = init_opts

# Configurations to be overwritten
overwrite_conf = opts
overwrite_conf["spark.master"] = master
overwrite_conf["spark.local.connect"] = "1"
os.environ["SPARK_LOCAL_CONNECT"] = "1"

# Configurations to be set if unset.
default_conf = {"spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin"}
@@ -1030,6 +1058,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
finally:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
del os.environ["SPARK_LOCAL_CONNECT"]
else:
raise PySparkRuntimeError(
errorClass="SESSION_OR_CONTEXT_EXISTS",
