Added L1 and L2 regularization (lambdaL1, lambdaL2) to the LightGBM classifier and regressor
imatiach-msft authored and mhamilton723 committed Apr 23, 2019
1 parent cfead9a commit e59b234
Showing 7 changed files with 52 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/lightgbm/src/main/scala/LightGBMClassifier.scala
@@ -47,7 +47,8 @@ class LightGBMClassifier(override val uid: String)
ClassifierTrainParams(getParallelism, getNumIterations, getLearningRate, getNumLeaves,
getMaxBin, getBaggingFraction, getBaggingFreq, getBaggingSeed, getEarlyStoppingRound,
getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numWorkers, getObjective, modelStr,
getIsUnbalance, getVerbosity, categoricalIndexes, actualNumClasses, metric, getBoostFromAverage, getBoostingType)
getIsUnbalance, getVerbosity, categoricalIndexes, actualNumClasses, metric, getBoostFromAverage,
getBoostingType, getLambdaL1, getLambdaL2)
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMClassificationModel = {
12 changes: 12 additions & 0 deletions src/lightgbm/src/main/scala/LightGBMParams.scala
@@ -147,4 +147,16 @@ trait LightGBMParams extends Wrappable with DefaultParamsWritable with HasWeight

def getBoostingType: String = $(boostingType)
def setBoostingType(value: String): this.type = set(boostingType, value)

val lambdaL1 = new DoubleParam(this, "lambdaL1", "L1 regularization")
setDefault(lambdaL1 -> 0.0)

def getLambdaL1: Double = $(lambdaL1)
def setLambdaL1(value: Double): this.type = set(lambdaL1, value)

val lambdaL2 = new DoubleParam(this, "lambdaL2", "L2 regularization")
setDefault(lambdaL2 -> 0.0)

def getLambdaL2: Double = $(lambdaL2)
def setLambdaL2(value: Double): this.type = set(lambdaL2, value)
}
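
For orientation (not part of the commit), a minimal sketch of how the two new params would be set from caller code. The import path and the no-arg constructor are assumptions based on the usual MMLSpark/Spark ML conventions; only setLambdaL1/setLambdaL2 and their 0.0 defaults come from the diff above.

// Illustrative sketch only: exercising the new lambdaL1 / lambdaL2 params defined above.
// The import path and no-arg constructor are assumed, not shown in this diff.
import com.microsoft.ml.spark.LightGBMClassifier

val classifier = new LightGBMClassifier()
  .setLambdaL1(0.1)   // L1 regularization; defaults to 0.0 via setDefault above
  .setLambdaL2(0.5)   // L2 regularization; defaults to 0.0 via setDefault above

// getLambdaL1 / getLambdaL2 are read when ClassifierTrainParams is built
// (see LightGBMClassifier.scala above) and forwarded to the native LightGBM trainer.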
2 changes: 1 addition & 1 deletion src/lightgbm/src/main/scala/LightGBMRegressor.scala
@@ -58,7 +58,7 @@ class LightGBMRegressor(override val uid: String)
RegressorTrainParams(getParallelism, getNumIterations, getLearningRate, getNumLeaves,
getObjective, getAlpha, getTweedieVariancePower, getMaxBin, getBaggingFraction, getBaggingFreq, getBaggingSeed,
getEarlyStoppingRound, getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numWorkers, modelStr,
getVerbosity, categoricalIndexes, getBoostFromAverage, getBoostingType)
getVerbosity, categoricalIndexes, getBoostFromAverage, getBoostingType, getLambdaL1, getLambdaL2)
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRegressionModel = {
8 changes: 5 additions & 3 deletions src/lightgbm/src/main/scala/TrainParams.scala
@@ -25,6 +25,8 @@ abstract class TrainParams extends Serializable {
def categoricalFeatures: Array[Int]
def boostFromAverage: Boolean
def boostingType: String
def lambdaL1: Double
def lambdaL2: Double

override def toString(): String = {
s"is_pre_partition=True boosting_type=$boostingType tree_learner=$parallelism num_iterations=$numIterations " +
@@ -33,7 +35,7 @@ abstract class TrainParams extends Serializable {
s"bagging_seed=$baggingSeed early_stopping_round=$earlyStoppingRound " +
s"feature_fraction=$featureFraction max_depth=$maxDepth min_sum_hessian_in_leaf=$minSumHessianInLeaf " +
s"num_machines=$numMachines objective=$objective verbosity=$verbosity " +
s"boost_from_average=${boostFromAverage.toString} " +
s"boost_from_average=${boostFromAverage.toString} lambda_l1=$lambdaL1 lambda_l2=$lambdaL2 " +
(if (categoricalFeatures.isEmpty) "" else s"categorical_feature=${categoricalFeatures.mkString(",")}")
}
}
@@ -47,7 +49,7 @@ case class ClassifierTrainParams(val parallelism: String, val numIterations: Int
val numMachines: Int, val objective: String, val modelString: Option[String],
val isUnbalance: Boolean, val verbosity: Int, val categoricalFeatures: Array[Int],
val numClass: Int, val metric: String, val boostFromAverage: Boolean,
val boostingType: String)
val boostingType: String, val lambdaL1: Double, val lambdaL2: Double)
extends TrainParams {
override def toString(): String = {
val extraStr =
@@ -67,7 +69,7 @@ case class RegressorTrainParams(val parallelism: String, val numIterations: Int,
val maxDepth: Int, val minSumHessianInLeaf: Double, val numMachines: Int,
val modelString: Option[String], val verbosity: Int,
val categoricalFeatures: Array[Int], val boostFromAverage: Boolean,
val boostingType: String)
val boostingType: String, val lambdaL1: Double, val lambdaL2: Double)
extends TrainParams {
override def toString(): String = {
s"alpha=$alpha tweedie_variance_power=$tweedieVariancePower ${super.toString}"
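
To make the toString change in TrainParams above concrete (this snippet is illustrative and not part of the commit), the new fragment of the native LightGBM parameter string can be reproduced in isolation:

// Self-contained sketch of the fragment TrainParams.toString now emits; values are illustrative.
val boostFromAverage = true
val lambdaL1 = 0.1
val lambdaL2 = 0.5
val fragment = s"boost_from_average=${boostFromAverage.toString} lambda_l1=$lambdaL1 lambda_l2=$lambdaL2 "
// fragment == "boost_from_average=true lambda_l1=0.1 lambda_l2=0.5 "
// The full string also carries boosting_type, num_iterations, learning_rate, and the other
// parameters shown in the surrounding interpolation.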
22 changes: 20 additions & 2 deletions src/lightgbm/src/main/scala/TrainUtils.scala
@@ -269,6 +269,23 @@ private object TrainUtils extends Serializable {
}.get
}

def networkInit(nodes: String, localListenPort: Int, log: Logger, retry: Int, delay: Long): Unit = {
try {
LightGBMUtils.validate(lightgbmlib.LGBM_NetworkInit(nodes, localListenPort,
LightGBMConstants.defaultListenTimeout, nodes.split(",").length), "Network init")
} catch {
case ex: Throwable => {
log.info(s"NetworkInit failed with exception on local port $localListenPort, retrying: $ex")
Thread.sleep(delay)
if (retry > 0) {
networkInit(nodes, localListenPort, log, retry - 1, delay * 2)
} else {
throw ex
}
}
}
}
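
The new networkInit helper above retries LGBM_NetworkInit with exponential backoff: every failure sleeps for delay milliseconds, doubles the delay, and recurses while retry is positive, so retry = 3 and delay = 1000 allow up to four attempts (sleeping 1s, 2s, 4s, and a final 8s before the exception is rethrown). A generic sketch of the same pattern, with hypothetical names and no dependency on the LightGBM bindings:

// Generic retry-with-exponential-backoff sketch (illustrative, not part of the commit);
// mirrors networkInit above: sleep, double the delay, recurse while retries remain, else rethrow.
def retryWithBackoff[T](retry: Int, delayMs: Long)(op: => T): T = {
  try {
    op
  } catch {
    case ex: Throwable =>
      Thread.sleep(delayMs)
      if (retry > 0) retryWithBackoff(retry - 1, delayMs * 2)(op)
      else throw ex
  }
}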

def trainLightGBM(networkParams: NetworkParams, labelColumn: String, featuresColumn: String,
weightColumn: Option[String], validationData: Option[Broadcast[Array[Row]]], log: Logger,
trainParams: TrainParams, numCoresPerExec: Int)
@@ -293,8 +310,9 @@
// Initialize the network communication
log.info(s"LightGBM worker listening on: $localListenPort")
try {
LightGBMUtils.validate(lightgbmlib.LGBM_NetworkInit(nodes, localListenPort,
LightGBMConstants.defaultListenTimeout, nodes.split(",").length), "Network init")
val retries = 3
val initialDelay = 1000L
networkInit(nodes, localListenPort, log, retries, initialDelay)
translate(labelColumn, featuresColumn, weightColumn, validationData, log, trainParams, inputRows)
} finally {
// Finalize network when done
6 changes: 6 additions & 0 deletions src/lightgbm/src/test/scala/VerifyLightGBMClassifier.scala
@@ -69,6 +69,8 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
val paramGrid = new ParamGridBuilder()
.addGrid(lgbm.numLeaves, Array(5, 10))
.addGrid(lgbm.numIterations, Array(10, 20))
.addGrid(lgbm.lambdaL1, Array(0.1, 0.5))
.addGrid(lgbm.lambdaL2, Array(0.1, 0.5))
.build()

val trainValidationSplit = new TrainValidationSplit()
@@ -82,6 +84,10 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
val model = trainValidationSplit.fit(featurizer.transform(dataset))
model.transform(featurizer.transform(dataset))
assert(model != null)
// Validate lambda parameters set on model
val modelStr = model.bestModel.asInstanceOf[LightGBMClassificationModel].getModel.model
assert(modelStr.contains("[lambda_l1: 0.1]") || modelStr.contains("[lambda_l1: 0.5]"))
assert(modelStr.contains("[lambda_l2: 0.1]") || modelStr.contains("[lambda_l2: 0.5]"))
}

test("Verify LightGBM Classifier with weight column") {
6 changes: 6 additions & 0 deletions src/lightgbm/src/test/scala/VerifyLightGBMRegressor.scala
@@ -50,6 +50,8 @@ class VerifyLightGBMRegressor extends Benchmarks with EstimatorFuzzing[LightGBMR
val paramGrid = new ParamGridBuilder()
.addGrid(lgbm.numLeaves, Array(5, 10))
.addGrid(lgbm.numIterations, Array(10, 20))
.addGrid(lgbm.lambdaL1, Array(0.1, 0.5))
.addGrid(lgbm.lambdaL2, Array(0.1, 0.5))
.build()

val trainValidationSplit = new TrainValidationSplit()
@@ -63,6 +65,10 @@ class VerifyLightGBMRegressor extends Benchmarks with EstimatorFuzzing[LightGBMR
val model = trainValidationSplit.fit(featurizer.transform(dataset))
model.transform(featurizer.transform(dataset))
assert(model != null)
// Validate lambda parameters set on model
val modelStr = model.bestModel.asInstanceOf[LightGBMRegressionModel].getModel.model
assert(modelStr.contains("[lambda_l1: 0.1]") || modelStr.contains("[lambda_l1: 0.5]"))
assert(modelStr.contains("[lambda_l2: 0.1]") || modelStr.contains("[lambda_l2: 0.5]"))
}

test("Verify LightGBM Regressor with weight column") {
