
GH-6723 - AdaBoost #15639

Merged: 34 commits, Sep 26, 2023
Changes from all commits
Commits
60382e6
GH-6723 - First version of Adaboost with hardcoded weaklearner to DRF…
valenad1 Jun 21, 2023
11a0e26
WIP add glm learner
valenad1 Jul 26, 2023
2cc9785
Add probabilities by Obtaining Calibrated Probabilities from Boosting…
valenad1 Jul 28, 2023
87c13d9
try to fix a test
valenad1 Aug 18, 2023
6fa3b5d
try to fix test
valenad1 Aug 25, 2023
89f2d44
Try to fix the tests
valenad1 Aug 28, 2023
d23e46c
try to fix the tests - this works - pass weights externally each time
valenad1 Aug 29, 2023
4c768fb
Add weights inside of the algorithm - this is working
valenad1 Aug 29, 2023
40e6285
fix for GLM
valenad1 Aug 30, 2023
0fd0a43
GH-6723 - add unit test to inner tasks
valenad1 Aug 30, 2023
109b348
use test files again in the large tests
valenad1 Aug 31, 2023
f994b1e
Improve basic training test to look into structure of weak learners
valenad1 Sep 5, 2023
3c07b25
clenup - remove isAdaBoost
valenad1 Sep 5, 2023
98191ba
add simple model summary
valenad1 Sep 5, 2023
589033d
fix java api
valenad1 Sep 6, 2023
b413535
Implement possibility to have a custom weights column and ensure that…
valenad1 Sep 14, 2023
6077424
Add categorical test
valenad1 Sep 14, 2023
7a65eb7
Cleanup that didn't change tests results
valenad1 Sep 15, 2023
e0d4845
Remove toCSV since there is API already and commented code
valenad1 Sep 15, 2023
f84d9cb
Refactor learning rate to learn rate
valenad1 Sep 15, 2023
e3a0309
Add documentation and validation to parameters
valenad1 Sep 15, 2023
5345b9b
Add documentation to AdaBoost class
valenad1 Sep 15, 2023
1840cab
add log
valenad1 Sep 15, 2023
0485715
Refactor AdaBoost - simple refactor
valenad1 Sep 18, 2023
4810d4f
Fix GLM as a weak learner
valenad1 Sep 18, 2023
40bca78
Add GBM as a weak learner
valenad1 Sep 18, 2023
32d96b2
test cleanup
valenad1 Sep 19, 2023
8574bd2
Refactor n_estimators to nlearners
valenad1 Sep 19, 2023
bf7a435
fixup! Implement possibility to have a custom weights column and ensu…
valenad1 Sep 22, 2023
445dc39
Fix for different model as a weak learner - use upperclass instead of…
valenad1 Sep 22, 2023
4ec9229
Ensure that adaboost create exactly nlearners models
valenad1 Sep 22, 2023
e52c287
Refactoring according to suggestions
valenad1 Sep 22, 2023
aeaa11b
fixup! Fix for different model as a weak learner - use upperclass ins…
valenad1 Sep 22, 2023
5b2b780
GH-6723 AdaBoost API (#15732)
valenad1 Sep 26, 2023
242 changes: 242 additions & 0 deletions h2o-algos/src/main/java/hex/adaboost/AdaBoost.java
@@ -0,0 +1,242 @@
package hex.adaboost;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import org.apache.log4j.Logger;
import water.*;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Timer;
import water.util.TwoDimTable;

import java.util.ArrayList;
import java.util.List;

/**
* Implementation of the AdaBoost algorithm, based on:
*
* Raul Rojas, "AdaBoost and the Super Bowl of Classifiers: A Tutorial Introduction to Adaptive Boosting"
* Alexandru Niculescu-Mizil and Richard A. Caruana, "Obtaining Calibrated Probabilities from Boosting"
* Y. Freund and R. Schapire, "A Decision-Theoretic Generalization of On-Line Learning and an Application to Boosting", 1995.
*
* @author Adam Valenta
*/
public class AdaBoost extends ModelBuilder<AdaBoostModel, AdaBoostModel.AdaBoostParameters, AdaBoostModel.AdaBoostOutput> {
private static final Logger LOG = Logger.getLogger(AdaBoost.class);
private static final int MAX_LEARNERS = 100_000;

private AdaBoostModel _model;
private String _weightsName = "weights";

// Called from an http request
public AdaBoost(AdaBoostModel.AdaBoostParameters parms) {
super(parms);
init(false);
}

public AdaBoost(boolean startup_once) {
super(new AdaBoostModel.AdaBoostParameters(), startup_once);
}

@Override
public boolean havePojo() {
return false;
}

@Override
public boolean haveMojo() {
return false;
}

@Override
public void init(boolean expensive) {
super.init(expensive);
if (_parms._nlearners < 1 || _parms._nlearners > MAX_LEARNERS)
error("nlearners", "Parameter nlearners must be in interval [1, "
+ MAX_LEARNERS + "] but it is " + _parms._nlearners);
if (_parms._weak_learner == AdaBoostModel.Algorithm.AUTO) {
_parms._weak_learner = AdaBoostModel.Algorithm.DRF;
}
if (_parms._weights_column != null) {
// _parms._weights_column cannot be used directly throughout training since reusing it breaks scoring
_weightsName = _parms._weights_column;
}
if (!(0.0 < _parms._learn_rate && _parms._learn_rate <= 1.0)) {
error("learn_rate", "learn_rate must be in the interval (0, 1] but it is " + _parms._learn_rate);
}
}

private class AdaBoostDriver extends Driver {

@Override
public void computeImpl() {
_model = null;
try {
init(true);
if (error_count() > 0) {
throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(AdaBoost.this);
}
_model = new AdaBoostModel(dest(), _parms,
new AdaBoostModel.AdaBoostOutput(AdaBoost.this));
_model.delete_and_lock(_job);
buildAdaboost();
LOG.info(_model.toString());
} finally {
if (_model != null)
_model.unlock(_job);
}
}

private void buildAdaboost() {
_model._output.alphas = new double[_parms._nlearners];
_model._output.models = new Key[_parms._nlearners];

Frame _trainWithWeights;
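// If the user supplied no weights column, clone the training frame and attach
// a fresh weights Vec constant at 1.0 (makeCons(1, 1, null, null) builds one
// constant Vec of ones); the boosting loop then updates these weights each round.
// A user-supplied weights column is taken as the initial weight distribution.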
if (_parms._weights_column == null) {
_trainWithWeights = new Frame(train());
Vec weights = _trainWithWeights.anyVec().makeCons(1,1,null,null)[0];
_weightsName = _trainWithWeights.uniquify(_weightsName); // make sure we do not collide with an existing column of the training frame
_trainWithWeights.add(_weightsName, weights);
DKV.put(_trainWithWeights);
Scope.track(weights);
} else {
_trainWithWeights = _parms.train();
}

for (int n = 0; n < _parms._nlearners; n++) {
Timer timer = new Timer();
ModelBuilder job = chooseWeakLearner(_trainWithWeights);
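// Offset the seed each round so successive weak learners do not repeat the
// same random choices (e.g. the single-feature sampling of DRF with _mtries = 1).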
job._parms._seed += n;
Model model = (Model) job.trainModel().get();
DKV.put(model);
Scope.untrack(model._key);
_model._output.models[n] = model._key;
Frame predictions = model.score(_trainWithWeights);
Scope.track(predictions);

CountWeTask countWe = new CountWeTask().doAll(_trainWithWeights.vec(_weightsName), _trainWithWeights.vec(_parms._response_column), predictions.vec("predict"));
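// CountWeTask accumulates the total row weight W and the weight We of
// misclassified rows, so e_m = We / W is this learner's weighted error and
// alpha_m = learn_rate * ln((1 - e_m) / e_m) is its vote weight (Freund & Schapire).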
double eM = countWe.We / countWe.W;
double alphaM = _parms._learn_rate * Math.log((1 - eM) / eM);
_model._output.alphas[n] = alphaM;

UpdateWeightsTask updateWeightsTask = new UpdateWeightsTask(alphaM);
updateWeightsTask.doAll(_trainWithWeights.vec(_weightsName), _trainWithWeights.vec(_parms._response_column), predictions.vec("predict"));
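// The updated weights make misclassified rows count more in the next round;
// in classic AdaBoost this scales their weight by exp(alpha_m) (the exact
// update lives in UpdateWeightsTask, not shown in this file).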
_job.update(1);
_model.update(_job);
LOG.info("Weak learner " + (n + 1) + " of " + _parms._nlearners + " built in " + timer.toString());
LOG.info("*********************************************************************");
}
if (_trainWithWeights != _parms.train()) {
DKV.remove(_trainWithWeights._key);
}
_model._output._model_summary = createModelSummaryTable();
}
}

@Override
protected Driver trainModelImpl() {
return new AdaBoostDriver();
}

@Override
public BuilderVisibility builderVisibility() {
return BuilderVisibility.Experimental;
}

@Override
public ModelCategory[] can_build() {
return new ModelCategory[]{
ModelCategory.Binomial,
};
}

@Override
public boolean isSupervised() {
return true;
}

private ModelBuilder chooseWeakLearner(Frame frame) {
switch (_parms._weak_learner) {
case GLM:
return getGLMWeakLearner(frame);
case GBM:
return getGBMWeakLearner(frame);
default:
case DRF:
return getDRFWeakLearner(frame);
}
}

private DRF getDRFWeakLearner(Frame frame) {
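// Constrain DRF to a single decision stump: one tree (_ntrees = 1) of depth one
// (_max_depth = 1), one randomly chosen feature per split (_mtries = 1), and no
// row sampling (_sample_rate = 1), so each weak learner stays deliberately weak.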
DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
parms._train = frame._key;
parms._response_column = _parms._response_column;
parms._weights_column = _weightsName;
parms._mtries = 1;
parms._min_rows = 1;
parms._ntrees = 1;
parms._sample_rate = 1;
parms._max_depth = 1;
parms._seed = _parms._seed;
return new DRF(parms);
}

private GLM getGLMWeakLearner(Frame frame) {
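// GLM is left at its defaults: a linear model is already a weak learner, and it
// picks up the boosting distribution through the shared weights column.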
GLMModel.GLMParameters parms = new GLMModel.GLMParameters();
parms._train = frame._key;
parms._response_column = _parms._response_column;
parms._weights_column = _weightsName;
parms._seed = _parms._seed;
return new GLM(parms);
}

private GBM getGBMWeakLearner(Frame frame) {
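// Same stump constraints as the DRF weak learner: a single depth-1 tree trained
// on all rows.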
GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
parms._train = frame._key;
parms._response_column = _parms._response_column;
parms._weights_column = _weightsName;
parms._min_rows = 1;
parms._ntrees = 1;
parms._sample_rate = 1;
parms._max_depth = 1;
parms._seed = _parms._seed;
return new GBM(parms);
}

public TwoDimTable createModelSummaryTable() {
List<String> colHeaders = new ArrayList<>();
List<String> colTypes = new ArrayList<>();
List<String> colFormat = new ArrayList<>();

colHeaders.add("Number of weak learners"); colTypes.add("int"); colFormat.add("%d");
colHeaders.add("Learn rate"); colTypes.add("int"); colFormat.add("%d");
colHeaders.add("Weak learner"); colTypes.add("int"); colFormat.add("%d");
colHeaders.add("Seed"); colTypes.add("long"); colFormat.add("%d");

final int rows = 1;
TwoDimTable table = new TwoDimTable(
"Model Summary", null,
new String[rows],
colHeaders.toArray(new String[0]),
colTypes.toArray(new String[0]),
colFormat.toArray(new String[0]),
"");
int row = 0;
int col = 0;
table.set(row, col++, _parms._nlearners);
table.set(row, col++, _parms._learn_rate);
table.set(row, col++, _parms._weak_learner.toString());
table.set(row, col, _parms._seed);
return table;
}

}
133 changes: 133 additions & 0 deletions h2o-algos/src/main/java/hex/adaboost/AdaBoostModel.java
@@ -0,0 +1,133 @@
package hex.adaboost;

import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import org.apache.log4j.Logger;
import water.*;

public class AdaBoostModel extends Model<AdaBoostModel, AdaBoostModel.AdaBoostParameters, AdaBoostModel.AdaBoostOutput> {
private static final Logger LOG = Logger.getLogger(AdaBoostModel.class);

public enum Algorithm {DRF, GLM, GBM, AUTO}

public AdaBoostModel(Key<AdaBoostModel> selfKey, AdaBoostParameters parms,
AdaBoostOutput output) {
super(selfKey, parms, output);
}

@Override
public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
if (_output.getModelCategory() == ModelCategory.Binomial) {
return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
}
throw H2O.unimpl("AdaBoost currently supports only binary classification");
}

@Override
protected String[] makeScoringNames(){
return new String[]{"predict", "p0", "p1"};
}

@Override
protected double[] score0(double[] data, double[] preds) {
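// Weighted vote: each weak learner contributes +alpha_i (when it predicts class 1)
// or -alpha_i (class 0) to the linear combination. The predicted class is the side
// with the larger total alpha, and p1 = 1 / (1 + exp(-2 * linearCombination)) is the
// logistic calibration from Niculescu-Mizil & Caruana, "Obtaining Calibrated
// Probabilities from Boosting".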
double alphas0 = 0;
double alphas1 = 0;
double linearCombination = 0;
for (int i = 0; i < _output.alphas.length; i++) {
Model model = DKV.getGet(_output.models[i]);
if (model.score(data) == 0) {
linearCombination += _output.alphas[i]*-1;
alphas0 += _output.alphas[i];
} else {
linearCombination += _output.alphas[i];
alphas1 += _output.alphas[i];
}
}
preds[0] = alphas0 > alphas1 ? 0 : 1;
preds[2] = 1/(1 + Math.exp(-2*linearCombination));
preds[1] = 1 - preds[2];
return preds;
}

@Override protected boolean needsPostProcess() { return false; /* pred[0] is already set by score0 */ }

public static class AdaBoostOutput extends Model.Output {
public double[] alphas;
public Key<Model>[] models;

public AdaBoostOutput(AdaBoost b) {
super(b);
}
}

@Override
protected Futures remove_impl(Futures fs, boolean cascade) {
for (Key<Model> iTreeKey : _output.models) {
Keyed.remove(iTreeKey, fs, true);
}
return super.remove_impl(fs, cascade);
}

@Override
protected AutoBuffer writeAll_impl(AutoBuffer ab) {
for (Key<Model> iTreeKey : _output.models) {
ab.putKey(iTreeKey);
}
return super.writeAll_impl(ab);
}

@Override
protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
for (Key<Model> iTreeKey : _output.models) {
ab.getKey(iTreeKey, fs);
}
return super.readAll_impl(ab,fs);
}

public static class AdaBoostParameters extends Model.Parameters {

/**
* Number of weak learners to train. Defaults to 50.
*/
public int _nlearners;

/**
* Choose a weak learner type. Defaults to DRF.
*/
public Algorithm _weak_learner;

/**
* Specify how quickly the training converges. Must be in (0, 1]. Defaults to 0.5.
*/
public double _learn_rate;

@Override
public String algoName() {
return "AdaBoost";
}

@Override
public String fullName() {
return "AdaBoost";
}

@Override
public String javaName() {
return AdaBoostModel.class.getName();
}

@Override
public long progressUnits() {
return _nlearners;
}

public AdaBoostParameters() {
super();
_nlearners = 50;
_weak_learner = Algorithm.AUTO;
_learn_rate = 0.5;
}
}
}
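For orientation, here is a minimal sketch of how the new builder can be driven from Java, using only the API introduced in this diff; trainFrame, testFrame, and the "label" column are placeholders rather than anything from the PR:

// Assumes trainFrame and testFrame are already-parsed Frames with a binary response "label".
AdaBoostModel.AdaBoostParameters parms = new AdaBoostModel.AdaBoostParameters();
parms._train = trainFrame._key;
parms._response_column = "label";
parms._nlearners = 50;                              // constructor default
parms._weak_learner = AdaBoostModel.Algorithm.AUTO; // init() resolves AUTO to DRF
parms._learn_rate = 0.5;                            // must lie in (0, 1]
AdaBoostModel model = new AdaBoost(parms).trainModel().get();
Frame scored = model.score(testFrame);              // columns: predict, p0, p1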