[SPARK-18280][CORE] Fix potential deadlock in `StandaloneSchedulerBackend.dead`

## What changes were proposed in this pull request?

"StandaloneSchedulerBackend.dead" is called in a RPC thread, so it should not call "SparkContext.stop" in the same thread. "SparkContext.stop" will block until all RPC threads exit, if it's called inside a RPC thread, it will be dead-lock.

This PR adds a thread-local flag inside RPC threads. `SparkContext.stop` uses it to decide whether to launch a new thread to stop the SparkContext.
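
A minimal sketch of that mechanism (simplified and with hypothetical names; the real change is spread across `NettyRpcEnv`, `Dispatcher`, and `SparkContext` in the diff below): the RPC message loop sets a `DynamicVariable` for its own thread, and `stop` consults it to decide whether to hand the blocking shutdown off to a fresh daemon thread.

```scala
import scala.util.DynamicVariable

object RpcThreadFlagSketch {
  // Thread-local flag; defaults to false on threads that never run the message loop.
  val rpcThreadFlag = new DynamicVariable[Boolean](false)

  def isInRpcThread: Boolean = rpcThreadFlag.value

  // A message loop would run each handler with the flag set for its own thread.
  def handleMessage(handler: => Unit): Unit = rpcThreadFlag.withValue(true)(handler)

  // `doStop` stands in for the blocking shutdown work (`_stop()` in the patch).
  def stop(doStop: () => Unit): Unit = {
    if (isInRpcThread) {
      // Hand off to a new daemon thread so the RPC thread stays free to exit.
      val stopper = new Thread("stop-spark-context") {
        override def run(): Unit = doStop()
      }
      stopper.setDaemon(true)
      stopper.start()
    } else {
      doStop()
    }
  }
}
```

Because the flag lives in a thread-local, setting it inside a message-loop thread does not affect threads started elsewhere, which is what the new `isInRPCThread` test below asserts.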

## How was this patch tested?

Jenkins

Author: Shixiong Zhu <[email protected]>

Closes apache#15775 from zsxwing/SPARK-18280.
zsxwing committed Nov 8, 2016
1 parent f441b9a commit 8aa419b
Showing 5 changed files with 41 additions and 2 deletions.
22 changes: 20 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1734,8 +1734,26 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient
    */
   def listJars(): Seq[String] = addedJars.keySet.toSeq
 
-  // Shut down the SparkContext.
-  def stop() {
+  /**
+   * Shut down the SparkContext.
+   */
+  def stop(): Unit = {
+    if (env.rpcEnv.isInRPCThread) {
+      // `stop` will block until all RPC threads exit, so we cannot call stop inside a RPC thread.
+      // We should launch a new thread to call `stop` to avoid dead-lock.
+      new Thread("stop-spark-context") {
+        setDaemon(true)
+
+        override def run(): Unit = {
+          _stop()
+        }
+      }.start()
+    } else {
+      _stop()
+    }
+  }
+
+  private def _stop() {
     if (LiveListenerBus.withinListenerThread.value) {
       throw new SparkException(
         s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}")
4 changes: 4 additions & 0 deletions core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -135,6 +135,10 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
    */
   def openChannel(uri: String): ReadableByteChannel
 
+  /**
+   * Return if the current thread is a RPC thread.
+   */
+  def isInRPCThread: Boolean
 }
 
 /**
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -201,6 +201,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
   /** Message loop used for dispatching messages. */
   private class MessageLoop extends Runnable {
     override def run(): Unit = {
+      NettyRpcEnv.rpcThreadFlag.value = true
       try {
         while (true) {
           try {
3 changes: 3 additions & 0 deletions core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -408,10 +408,13 @@ private[netty] class NettyRpcEnv(
 
   }
 
+  override def isInRPCThread: Boolean = NettyRpcEnv.rpcThreadFlag.value
 }
 
 private[netty] object NettyRpcEnv extends Logging {
 
+  private[netty] val rpcThreadFlag = new DynamicVariable[Boolean](false)
+
   /**
    * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
    * Use `currentEnv` to wrap the deserialization codes. E.g.,
13 changes: 13 additions & 0 deletions core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -870,6 +870,19 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     verify(endpoint, never()).onDisconnected(any())
     verify(endpoint, never()).onNetworkError(any(), any())
   }
+
+  test("isInRPCThread") {
+    val rpcEndpointRef = env.setupEndpoint("isInRPCThread", new RpcEndpoint {
+      override val rpcEnv = env
+
+      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+        case m => context.reply(rpcEnv.isInRPCThread)
+      }
+    })
+    assert(rpcEndpointRef.askWithRetry[Boolean]("hello") === true)
+    assert(env.isInRPCThread === false)
+    env.stop(rpcEndpointRef)
+  }
 }
 
 class UnserializableClass
