From 2f508ea1cb4fd37db0d69db2d943c5c526a3c446 Mon Sep 17 00:00:00 2001 From: xianghui Date: Tue, 11 Jun 2024 18:57:05 +0800 Subject: [PATCH] fix bind port failed on spark yarn environment --- .../ml/lightgbm/BasePartitionTask.scala | 7 ++++ .../synapse/ml/lightgbm/NetworkManager.scala | 34 +++++++++++-------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/BasePartitionTask.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/BasePartitionTask.scala index 6dccaa84f6..cca2283c45 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/BasePartitionTask.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/BasePartitionTask.scala @@ -140,6 +140,13 @@ abstract class BasePartitionTask extends Serializable with Logging { try { if (taskCtx.shouldExecuteTraining) { + //close socket before lightgbm bind port + try { + taskCtx.networkTopologyInfo.localSocket.close() + } catch { + case e: Exception => log.warn("close local bind port socket failed ") + } + // If participating in training, initialize the network ring of communication NetworkManager.initLightGBMNetwork(taskCtx, log) diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/NetworkManager.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/NetworkManager.scala index 89aef07a50..c019540be5 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/NetworkManager.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/NetworkManager.scala @@ -39,7 +39,8 @@ case class TaskMessageInfo(status: String, case class NetworkTopologyInfo(lightgbmNetworkString: String, executorPartitionIdList: Array[Int], - localListenPort: Int) + localListenPort: Int, + localSocket: Socket) object NetworkManager { /** @@ -107,19 +108,21 @@ object NetworkManager { measures: TaskInstrumentationMeasures): NetworkTopologyInfo = { measures.markNetworkInitializationStart() val networkParams = ctx.networkParams - val out = using(findOpenPort(ctx, log).get) { - openPort => - val localListenPort = openPort.getLocalPort - log.info(s"LightGBM task $taskId connecting to host: ${networkParams.ipAddress}, port: ${networkParams.port}") - FaultToleranceUtils.retryWithTimeout() { - getNetworkTopologyInfoFromDriver(networkParams, - taskId, - partitionId, - localListenPort, - log, - shouldExecuteTraining) - } - }.get + val localSocket = findOpenPort(ctx, log).get + //hold socket until lightgbm bind port + val out = ((localSocket: Socket) => { + val localListenPort = localSocket.getLocalPort + log.info(s"LightGBM task $taskId connecting to host: ${networkParams.ipAddress}, port: ${networkParams.port}") + FaultToleranceUtils.retryWithTimeout() { + getNetworkTopologyInfoFromDriver(networkParams, + taskId, + partitionId, + localListenPort, + localSocket, + log, + shouldExecuteTraining) + } + })(localSocket) measures.markNetworkInitializationStop() out } @@ -128,6 +131,7 @@ object NetworkManager { taskId: Long, partitionId: Int, localListenPort: Int, + localSocket: Socket, log: Logger, shouldExecuteTraining: Boolean): NetworkTopologyInfo = { using(new Socket(networkParams.ipAddress, networkParams.port)) { @@ -173,7 +177,7 @@ object NetworkManager { log.info(s"task $taskId, partition $partitionId received nodes for network init: '$lightGbmMachineList'") val executorPartitionIds: Array[Int] = parseExecutorPartitionList(partitionsByExecutorStr, taskStatus.executorId, log) - NetworkTopologyInfo(lightGbmMachineList, executorPartitionIds, localListenPort) + NetworkTopologyInfo(lightGbmMachineList, executorPartitionIds, localListenPort,localSocket) }.get }.get }