
Dropout Support in NN and GBDT

Zhang Pengshan (David) edited this page Jul 2, 2017 · 5 revisions

Dropout is a well-established technique for reducing overfitting in deep learning. Our neural network (NN) and gradient boosted decision tree (GBDT) training also support dropout, so users can tune a more stable model that is less prone to overfitting.

Dropout Support in NN

To enable it:

  "train" : {
    "baggingNum" : 5,
    ...
    "algorithm" : "NN",
    "params" : {
      "NumHiddenLayers" : 2,
      "ActivationFunc" : [ "Sigmoid", "Sigmoid" ],
      "NumHiddenNodes" : [ 45, 45 ],
      "LearningRate" : 0.1,
      "FeatureSubsetStrategy" : 1,
      "DropoutRate": 0.1,
      "Propagation" : "Q"
    },
  },

Just set 'DropoutRate' in the NN training parameters. Here 'DropoutRate' is the probability that the gradient update of a weight is dropped in an epoch. With a value of 0.1, about 10% of the weights are not updated in a given epoch, and those weights are selected at random in each epoch. Note for GBDT (see the next section): the first tree is never dropped, since it is the initial model in gradient boosting.
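The NN behaviour described above (skipping the update of a randomly chosen share of weights in each epoch) can be sketched in a few lines of Python. This is only an illustration, not the actual training code: the function name 'apply_update_with_dropout' is hypothetical, and a plain gradient step is assumed for simplicity, whereas the configuration above uses "Propagation" : "Q".

```python
import random


def apply_update_with_dropout(weights, gradients, learning_rate, dropout_rate, rng):
    """Return new weights; each weight's update is skipped with probability dropout_rate.

    Assumption: a plain gradient step (w - learning_rate * g) stands in for
    whatever propagation rule the real trainer applies.
    """
    updated = []
    for w, g in zip(weights, gradients):
        if rng.random() < dropout_rate:
            updated.append(w)  # update dropped: weight keeps its value this epoch
        else:
            updated.append(w - learning_rate * g)
    return updated


rng = random.Random(42)          # fixed seed so the run is repeatable
weights = [0.5] * 1000
gradients = [1.0] * 1000

# One epoch with DropoutRate = 0.1: roughly 10% of the weights stay unchanged.
new_weights = apply_update_with_dropout(weights, gradients, 0.1, 0.1, rng)
skipped = sum(1 for old, new in zip(weights, new_weights) if old == new)
print(f"{skipped} of {len(weights)} weights were not updated")
```

Because a fresh random selection is drawn on every call, a different subset of weights is skipped in each epoch.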

In deep learning, 0.5 is a good value in most cases. For our shallow neural network training, I would suggest 0.1-0.3; in testing, values in this range gave some improvement in model performance.

Dropout Support in GBDT

To enable it (as in NN):

  "train" : {
    "baggingNum" : 5,
    ...
    "algorithm" : "GBT",
    "params" : {
      "TreeNum" : 1000,
      "FeatureSubsetStrategy" : "ONETHIRD",
      "MaxDepth" : 6,
      "Impurity" : "variance",
      "LearningRate" : 0.02,
      "MinInstancesPerNode" : 5,
      "MinInfoGain" : 0.0,
      "DropoutRate": 0.1,
      "Loss" : "squared"
    },
  },

'DropoutRate' in GBDT is the probability with which each tree is dropped. With a value of 0.1, in each new iteration the outputs of about 10% of the trees are left out when the current residual is computed. In testing, values of 0.05-0.1 gave some improvement in model performance.
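The tree dropout described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the actual implementation: the function names are hypothetical, each tree is represented only by its (already learning-rate-scaled) output for a single sample, the squared-loss residual is used, and no rescaling of the remaining trees is applied. The first tree is always kept, as the initial model.

```python
import random


def predict_with_tree_dropout(tree_outputs, dropout_rate, rng):
    """Sum the tree outputs for one sample, dropping trees at random.

    Every tree except the first is skipped with probability dropout_rate;
    the first tree is the initial model and is never dropped.
    """
    total = 0.0
    for index, output in enumerate(tree_outputs):
        if index > 0 and rng.random() < dropout_rate:
            continue  # this tree is left out of the current prediction
        total += output
    return total


def residual(label, tree_outputs, dropout_rate, rng):
    """Squared-loss residual that the next tree would be fitted to."""
    return label - predict_with_tree_dropout(tree_outputs, dropout_rate, rng)


rng = random.Random(7)
tree_outputs = [0.30, 0.05, 0.04, 0.03]   # hypothetical per-tree outputs

print(residual(1.0, tree_outputs, 0.0, rng))   # no dropout: all trees used
print(residual(1.0, tree_outputs, 0.1, rng))   # DropoutRate = 0.1
print(residual(1.0, tree_outputs, 1.0, rng))   # only the first tree is kept
```

With a dropout rate of 0 the residual is computed from all trees, and with a rate of 1 only the initial tree contributes, which brackets the behaviour at intermediate values such as 0.05-0.1.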

For more details about dropout in GBDT, please check this paper: Dropout in GBDT.
