Skip to content

Commit

Permalink
Initial version
Browse files Browse the repository at this point in the history
  • Loading branch information
roczei committed Oct 4, 2024
1 parent 38f067d commit 09132c4
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 1 deletion.
16 changes: 16 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ class SparkContext(config: SparkConf) extends Logging {
private var _shutdownHookRef: AnyRef = _
private var _statusStore: AppStatusStore = _
private var _heartbeater: Heartbeater = _
private var _driverThreadDumpCollector: ThreadDumpCollector = _
private var _resources: immutable.Map[String, ResourceInformation] = _
private var _shuffleDriverComponents: ShuffleDriverComponents = _
private var _plugins: Option[PluginContainer] = None
Expand Down Expand Up @@ -613,6 +614,15 @@ class SparkContext(config: SparkConf) extends Logging {
conf.get(EXECUTOR_HEARTBEAT_INTERVAL))
_heartbeater.start()

// Create and start the thread dump collector for the Spark driver
if (_conf.get(DRIVER_THREAD_DUMP_COLLECTOR_ENABLED)) {
_driverThreadDumpCollector = new ThreadDumpCollector(
() => Utils.writeThreadDumpToFile(env),
"driver-threadDumpCollector",
conf.get(THREAD_DUMP_COLLECTOR_INTERVAL))
_driverThreadDumpCollector.start()
}

// start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's
// constructor
_taskScheduler.start()
Expand Down Expand Up @@ -2375,6 +2385,12 @@ class SparkContext(config: SparkConf) extends Logging {
}
_heartbeater = null
}
if (_conf.get(DRIVER_THREAD_DUMP_COLLECTOR_ENABLED) && _driverThreadDumpCollector != null) {
Utils.tryLogNonFatalError {
_driverThreadDumpCollector.stop()
}
_driverThreadDumpCollector = null
}
if (env != null && _heartbeatReceiver != null) {
Utils.tryLogNonFatalError {
env.rpcEnv.stop(_heartbeatReceiver)
Expand Down
54 changes: 54 additions & 0 deletions core/src/main/scala/org/apache/spark/ThreadDumpCollector.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

import java.util.concurrent.TimeUnit

import org.apache.spark.internal.Logging
import org.apache.spark.util.{ThreadUtils, Utils}


/**
* Creates a strack trace collector thread which will call the specified collectStackTraces
* function at intervals of intervalMs.
*
* @param collectThreadDumps the thread dump collector function to call.
* @param name the thread name for the thread dump collector.
* @param intervalMs the interval between stack trace collections.
*/
private[spark] class ThreadDumpCollector(
collectThreadDumps: () => Unit,
name: String,
intervalMs: Long) extends Logging {
// Executor for the stack collector task
private val threadDumpCollector = ThreadUtils.newDaemonSingleThreadScheduledExecutor(name)

/** Schedules a task to collect the stack traces */
def start(): Unit = {
val threadDumpCollectorTask = new Runnable() {
override def run(): Unit = Utils.logUncaughtExceptions(collectThreadDumps())
}
threadDumpCollector.scheduleAtFixedRate(threadDumpCollectorTask, intervalMs, intervalMs,
TimeUnit.MILLISECONDS)
}

def stop(): Unit = {
threadDumpCollector.shutdown()
threadDumpCollector.awaitTermination(10, TimeUnit.SECONDS)
}
}
20 changes: 20 additions & 0 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ private[spark] class Executor(
log"${LogMDC(OS_ARCH, System.getProperty("os.arch"))}")
logInfo(log"Java version ${LogMDC(JAVA_VERSION, System.getProperty("java.version"))}")

private var executorThreadDumpCollector: ThreadDumpCollector = _

private val executorShutdown = new AtomicBoolean(false)
val stopHookReference = ShutdownHookManager.addShutdownHook(
() => stop()
Expand Down Expand Up @@ -325,6 +327,15 @@ private[spark] class Executor(

heartbeater.start()

// Create and start the thread dump collector for the Spark executor
if (conf.get(EXECUTOR_THREAD_DUMP_COLLECTOR_ENABLED)) {
executorThreadDumpCollector = new ThreadDumpCollector(
() => Utils.writeThreadDumpToFile(env),
"executor-ThreadDumpCollector",
conf.get(THREAD_DUMP_COLLECTOR_INTERVAL))
executorThreadDumpCollector.start()
}

private val appStartTime = conf.getLong("spark.app.startTime", 0)

// To allow users to distribute plugins and their required files
Expand Down Expand Up @@ -445,6 +456,15 @@ private[spark] class Executor(
case NonFatal(e) =>
logWarning("Unable to stop heartbeater", e)
}
try {
if (conf.get(EXECUTOR_THREAD_DUMP_COLLECTOR_ENABLED) && executorThreadDumpCollector !=
null) {
executorThreadDumpCollector.stop()
}
} catch {
case NonFatal(e) =>
logWarning("Unable to stop the executor thread dump collector", e)
}
ShuffleBlockPusher.stop()
if (threadPool != null) {
threadPool.shutdown()
Expand Down
23 changes: 23 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2801,4 +2801,27 @@ package object config {
.version("4.0.0")
.timeConf(TimeUnit.MILLISECONDS)
.createOptional

private[spark] val DRIVER_THREAD_DUMP_COLLECTOR_ENABLED = ConfigBuilder("spark.driver" +
".threadDumpCollector.enabled")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

private[spark] val EXECUTOR_THREAD_DUMP_COLLECTOR_ENABLED = ConfigBuilder("spark.executor" +
".threadDumpCollector.enabled")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

private[spark] val THREAD_DUMP_COLLECTOR_INTERVAL =
ConfigBuilder("spark.threadDumpCollectorInterval")
.version("4.0.0")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("10s")

private[spark] val THREAD_DUMP_COLLECTOR_DIR = ConfigBuilder("spark.threadDumpCollector.dir")
.version("4.0.0")
.stringConf
.createWithDefault("/tmp/spark-thread-dumps")
}
29 changes: 28 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import java.nio.channels.Channels
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.security.SecureRandom
import java.time.Instant
import java.util.{Locale, Properties, Random, UUID}
import java.util.concurrent._
import java.util.concurrent.TimeUnit.NANOSECONDS
Expand All @@ -53,7 +54,8 @@ import org.apache.commons.codec.binary.Hex
import org.apache.commons.io.IOUtils
import org.apache.commons.lang3.{JavaVersion, SystemUtils}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
import org.apache.hadoop.fs.{FileSystem, FileUtil, FSDataOutputStream, Path}
import org.apache.hadoop.fs.permission.FsPermission
import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec}
import org.apache.hadoop.ipc.{CallerContext => HadoopCallerContext}
import org.apache.hadoop.ipc.CallerContext.{Builder => HadoopCallerContextBuilder}
Expand Down Expand Up @@ -2112,6 +2114,31 @@ private[spark] object Utils
}.map(threadInfoToThreadStackTrace)
}

def writeThreadDumpToFile(env: SparkEnv): Unit = {
val collectedThreadDump = getThreadDump().map(_.toString).mkString
val hadoopConf = SparkHadoopUtil.get.newConfiguration(env.conf)
val rootDir = env.conf.get(THREAD_DUMP_COLLECTOR_DIR)
var outputStream: FSDataOutputStream = null
val fileSystem: FileSystem = new Path(rootDir).getFileSystem(hadoopConf)
val threadDumpFilePermissions = new FsPermission(Integer.parseInt("770", 8).toShort)
val timestamp = Instant.now().getEpochSecond().toString()
val threadDumpFileName = env.conf.getAppId + "-" + env.executorId + "-" + timestamp + ".txt"
val dfsLogFile: Path = fileSystem.makeQualified(new Path(rootDir, threadDumpFileName))
try {
outputStream = SparkHadoopUtil.createFile(fileSystem, dfsLogFile, true)
fileSystem.setPermission(dfsLogFile, threadDumpFilePermissions)
outputStream.write(collectedThreadDump.getBytes(StandardCharsets.UTF_8))
outputStream.close()
} catch {
case e: Exception =>
logError(
log"Could not collect thread dump from executor ${
MDC(LogKeys.EXECUTOR_ID,
env.executorId)
}", e)
}
}

/** Return a heap dump. Used to capture dumps for the web UI */
def getHeapHistogram(): Array[String] = {
val pid = String.valueOf(ProcessHandle.current().pid())
Expand Down

0 comments on commit 09132c4

Please sign in to comment.