/** Class for converting column values to group ID.
 *
 * Ints can just be returned, but maps of Long and String values are maintained so that unique and
 * consistent values can be returned. Long values and String values are numbered independently of
 * each other, each sequence starting at 0 in first-seen order.
 */
class GroupIdManager {
  // Imported locally so the class does not rely on a file-level import being present.
  import scala.collection.mutable

  private val stringGroupIds = mutable.Map[String, Int]()
  private val longGroupIds = mutable.Map[Long, Int]()
  // Single lock guarding both maps; lookups and inserts must be atomic with respect to each other.
  private[this] val lock = new Object()

  /** Convert a group ID into a unique Int.
   *
   * @param groupValue The original group ID value (Int, Long or String)
   * @return The value itself when it is an Int; otherwise the Int assigned to this value the
   *         first time it was seen (the same value always maps to the same Int)
   * @throws IllegalArgumentException if the value is null or of an unsupported type
   */
  def getUniqueIdForGroup(groupValue: Any): Int = {
    groupValue match {
      case iVal: Int =>
        iVal // If it's already an Int, just return
      case lVal: Long =>
        idFor(longGroupIds, lVal)
      case sVal: String =>
        // Sized from the String map itself: sizing from the Long map gave every string the same ID.
        idFor(stringGroupIds, sVal)
      case null =>
        // Handled explicitly: the generic branch below would fail with a NullPointerException.
        throw new IllegalArgumentException("Unsupported group column value: null")
      case _ =>
        throw new IllegalArgumentException(s"Unsupported group column type: '${groupValue.getClass.getName}'")
    }
  }

  /** Look up the ID for `key`, assigning the next free index (the current map size) if unseen. */
  private def idFor[K](ids: mutable.Map[K, Int], key: K): Int =
    lock.synchronized {
      // The default is evaluated before insertion, so ids.size is the next unused index.
      ids.getOrElseUpdate(key, ids.size)
    }
}
StreamingPartitionTask extends BasePartitionTask { private def loadOneMetadataRow(state: StreamingState, row: Row, index: Int): Unit = { state.labelBuffer.setItem(index, row.getDouble(state.labelIndex).toFloat) if (state.hasWeights) state.weightBuffer.setItem(index, row.getDouble(state.weightIndex).toFloat) - if (state.hasGroups) state.groupBuffer.setItem(index, row.getAs[Int](state.groupIndex)) + if (state.hasGroups) { + val groupIdManager = state.ctx.sharedState.groupIdManager + state.groupBuffer.setItem(index, groupIdManager.getUniqueIdForGroup(row.getAs[Any](state.groupIndex))) + } // Initial scores are passed in column-based format, where the score for each class is contiguous if (state.hasInitialScores) { diff --git a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRankerTestData.scala b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRankerTestData.scala index 1e9ebfc91b..e66d8b3f8c 100644 --- a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRankerTestData.scala +++ b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRankerTestData.scala @@ -51,6 +51,7 @@ abstract class LightGBMRankerTestData extends Benchmarks with EstimatorFuzzing[L .setGroupCol(queryCol) .setDefaultListenPort(getAndIncrementPort()) .setRepartitionByGroupingColumn(false) + .setDataTransferMode(dataTransferMode) .setNumLeaves(5) .setNumIterations(10) } diff --git a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRegressorTestData.scala b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRegressorTestData.scala index e6e58b0dc6..be28beccf0 100644 --- a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRegressorTestData.scala +++ b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRegressorTestData.scala @@ -29,6 +29,7 @@ abstract class 
LightGBMRegressorTestData extends Benchmarks .setLabelCol(labelCol) .setFeaturesCol(featuresCol) .setDefaultListenPort(getAndIncrementPort()) + .setDataTransferMode(dataTransferMode) .setNumLeaves(5) .setNumIterations(10) } diff --git a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/VerifyLightGBMRankerStream.scala b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/VerifyLightGBMRankerStream.scala index 526fae8aa4..c649bb7763 100644 --- a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/VerifyLightGBMRankerStream.scala +++ b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/VerifyLightGBMRankerStream.scala @@ -3,6 +3,7 @@ package com.microsoft.azure.synapse.ml.lightgbm.split2 +import com.microsoft.azure.synapse.ml.lightgbm.LightGBMConstants import com.microsoft.azure.synapse.ml.lightgbm.dataset.DatasetUtils.countCardinality import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.linalg.Vectors @@ -14,6 +15,7 @@ import scala.language.postfixOps //scalastyle:off magic.number /** Tests to validate the functionality of LightGBM Ranker module in streaming mode. */ class VerifyLightGBMRankerStream extends LightGBMRankerTestData { + override val dataTransferMode: String = LightGBMConstants.StreamingDataTransferMode import spark.implicits._ @@ -76,6 +78,11 @@ class VerifyLightGBMRankerStream extends LightGBMRankerTestData { assert(counts === Seq(2, 3, 1)) } + test("verify cardinality counts: long" + executionModeSuffix) { + val counts = countCardinality(Seq(1L, 1L, 2L, 2L, 2L, 3L)) + assert(counts === Seq(2, 3, 1)) + } + test("verify cardinality counts: string" + executionModeSuffix) { val counts = countCardinality(Seq("a", "a", "b", "b", "b", "c")) assert(counts === Seq(2, 3, 1))