Skip to content

Commit

Permalink
GH-15556 Implement UpliftDRF MOJO (#15615)
Browse files Browse the repository at this point in the history
Implement UpliftDRF MOJO.
  • Loading branch information
maurever authored Sep 26, 2023
1 parent 2abadea commit 8856059
Show file tree
Hide file tree
Showing 30 changed files with 617 additions and 167 deletions.
13 changes: 12 additions & 1 deletion h2o-algos/src/main/java/hex/schemas/UpliftDRFModelV3.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package hex.schemas;

import hex.tree.uplift.UpliftDRFModel;
import water.api.API;

public class UpliftDRFModelV3 extends SharedTreeModelV3<UpliftDRFModel,
UpliftDRFModelV3,
Expand All @@ -9,8 +10,18 @@ public class UpliftDRFModelV3 extends SharedTreeModelV3<UpliftDRFModel,
UpliftDRFModel.UpliftDRFOutput,
UpliftDRFModelV3.UpliftDRFModelOutputV3> {

public static final class UpliftDRFModelOutputV3 extends SharedTreeModelV3.SharedTreeModelOutputV3<UpliftDRFModel.UpliftDRFOutput, UpliftDRFModelOutputV3> {}
/**
 * Output schema for an UpliftDRF model. Adds the default AUUC thresholds on top of
 * whatever the shared tree output schema already fills in.
 */
public static final class UpliftDRFModelOutputV3
        extends SharedTreeModelV3.SharedTreeModelOutputV3<UpliftDRFModel.UpliftDRFOutput, UpliftDRFModelOutputV3> {

    @API(help="Default thresholds to calculate AUUC metric. If validation is enabled, thresholds from validation metrics is saved here. Otherwise thresholds are from training metrics.")
    public double[] default_auuc_thresholds;

    /**
     * Lets the superclass populate the schema from {@code impl}, then copies the
     * model output's default AUUC thresholds into {@link #default_auuc_thresholds}.
     *
     * @param impl model output to read from
     * @return the schema instance returned by the superclass, with thresholds set
     */
    @Override
    public UpliftDRFModelV3.UpliftDRFModelOutputV3 fillFromImpl(UpliftDRFModel.UpliftDRFOutput impl) {
        final UpliftDRFModelV3.UpliftDRFModelOutputV3 schema = super.fillFromImpl(impl);
        schema.default_auuc_thresholds = impl._defaultAuucThresholds;
        return schema;
    }
}

public UpliftDRFV3.UpliftDRFParametersV3 createParametersSchema() { return new UpliftDRFV3.UpliftDRFParametersV3(); }
public UpliftDRFModelOutputV3 createOutputSchema() { return new UpliftDRFModelOutputV3(); }

Expand Down
1 change: 1 addition & 0 deletions h2o-algos/src/main/java/hex/tree/Score.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import hex.genmodel.GenModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.gbm.GBMModel;
import hex.tree.uplift.UpliftDRFModel;
import org.apache.log4j.Logger;
import water.Iced;
import water.Key;
Expand Down
1 change: 0 additions & 1 deletion h2o-algos/src/main/java/hex/tree/SharedTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,6 @@ protected final boolean doScoringAndSaveModel(boolean finalScoring, boolean oob,
out._training_metrics = mm;
if (oob) out._training_metrics._description = "Metrics reported on Out-Of-Bag training samples";
out._scored_train[out._ntrees].fillFrom(mm);

// Score again on validation data
if( _parms._valid != null) {
Frame v = new Frame(valid());
Expand Down
3 changes: 2 additions & 1 deletion h2o-algos/src/main/java/hex/tree/SharedTreeModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.glm.GLMModel;
import hex.tree.uplift.UpliftDRFModel;
import hex.util.LinearAlgebraUtils;
import org.apache.log4j.Logger;
import water.*;
Expand Down Expand Up @@ -166,6 +166,7 @@ public boolean forceStrictlyReproducibleHistograms() {
case Binomial: return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(),domain, _parms._auc_type);
case Regression: return new ModelMetricsRegression.MetricBuilderRegression();
case BinomialUplift: return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, ((UpliftDRFModel.UpliftDRFOutput)_output)._defaultAuucThresholds);
default: throw H2O.unimpl();
}
}
Expand Down
22 changes: 20 additions & 2 deletions h2o-algos/src/main/java/hex/tree/uplift/UpliftDRF.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package hex.tree.uplift;

import hex.*;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.upliftdrf.UpliftDrfMojoModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.*;
import org.apache.log4j.Logger;
Expand Down Expand Up @@ -50,12 +52,12 @@ public UpliftDRF(boolean startup_once) {

@Override
public boolean haveMojo() {
return false;
return true;
}

@Override
public boolean havePojo() {
return false;
return true;
}

@Override
Expand Down Expand Up @@ -472,6 +474,22 @@ static TwoDimTable createUpliftScoringHistoryTable(Model.Output _output,
return table;
}

/**
 * Builds a POJO writer for a generic model backed by an UpliftDRF MOJO.
 *
 * @param genericModel model whose key and output are handed to the writer
 * @param mojoModel    MOJO to read trees and encoding from; cast to
 *                     {@link UpliftDrfMojoModel} without a prior type check
 * @return a new {@link UpliftDrfPojoWriter}
 */
@Override
public PojoWriter makePojoWriter(Model<?, ?, ?> genericModel, MojoModel mojoModel) {
    final UpliftDrfMojoModel mojo = (UpliftDrfMojoModel) mojoModel;
    // Trees are extracted before the encoding is read, as in the original statement order.
    final CompressedTree[][] compressedTrees = MojoUtils.extractCompressedTrees(mojo);
    return new UpliftDrfPojoWriter(
            genericModel,
            mojo.getCategoricalEncoding(),
            false, // binomialOpt is always passed as false for this writer
            compressedTrees,
            mojo._balanceClasses);
}

/**
 * Stores the default AUUC thresholds on the model output. Thresholds come from the
 * validation metrics when those are present, otherwise from the training metrics.
 *
 * @param out model output to update; the selected metrics object is cast to
 *            {@link ModelMetricsBinomialUplift}
 */
@Override
protected void addCustomInfo(UpliftDRFModel.UpliftDRFOutput out) {
    final ModelMetricsBinomialUplift upliftMetrics = (ModelMetricsBinomialUplift)
            (out._validation_metrics != null ? out._validation_metrics : out._training_metrics);
    out._defaultAuucThresholds = upliftMetrics._auuc._ths;
}

@Override
protected UpliftScoreExtension makeScoreExtension() {
return new UpliftScoreExtension();
Expand Down
29 changes: 22 additions & 7 deletions h2o-algos/src/main/java/hex/tree/uplift/UpliftDRFModel.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package hex.tree.uplift;

import hex.*;
import hex.tree.CompressedForest;
import hex.tree.SharedTreeModel;
import hex.tree.SharedTreeModelWithContributions;
import hex.tree.SharedTreePojoWriter;
import hex.util.EffectiveParametersUtils;
import water.H2O;
import water.Key;

public class UpliftDRFModel extends SharedTreeModel<UpliftDRFModel, UpliftDRFModel.UpliftDRFParameters, UpliftDRFModel.UpliftDRFOutput> {
Expand All @@ -13,7 +14,6 @@ public static class UpliftDRFParameters extends SharedTreeModel.SharedTreeParame
public String algoName() { return "UpliftDRF"; }
public String fullName() { return "Uplift Distributed Random Forest"; }
public String javaName() { return UpliftDRFModel.class.getName(); }
public boolean _binomial_double_trees = false;


public enum UpliftMetricType { AUTO, KL, ChiSquared, Euclidean }
Expand All @@ -36,6 +36,9 @@ public long progressUnits() {
}

public static class UpliftDRFOutput extends SharedTreeModelWithContributions.SharedTreeOutput {

public double[] _defaultAuucThresholds; // thresholds for AUUC to calculate metrics

public UpliftDRFOutput( UpliftDRF b) { super(b); }

@Override
Expand All @@ -45,7 +48,11 @@ public ModelCategory getModelCategory() {

@Override
public boolean isBinomialClassifier() {
return false;
return true;
}

/**
 * Replaces the default AUUC thresholds kept on this output.
 * The array reference is stored as-is (no copy is made).
 *
 * @param defaultAuucThresholds thresholds to store
 */
public void setDefaultAuucThresholds(double[] defaultAuucThresholds) {
    _defaultAuucThresholds = defaultAuucThresholds;
}
}

Expand Down Expand Up @@ -77,10 +84,18 @@ public void initActualParamValues() {
}

@Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
if (_output.getModelCategory() == ModelCategory.BinomialUplift) {
return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain);
}
throw H2O.unimpl();
return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, _output._defaultAuucThresholds);
}

/**
 * Creates the MOJO writer for this model.
 *
 * @return a new {@link UpliftDrfMojoWriter} wrapping this model
 */
@Override
public UpliftDrfMojoWriter getMojo() {
    final UpliftDrfMojoWriter mojoWriter = new UpliftDrfMojoWriter(this);
    return mojoWriter;
}

/**
 * Creates the POJO writer for this model: builds a compressed forest from the output's
 * tree keys and domains, fetches it, and passes the fetched trees to the writer.
 *
 * @return a new {@link UpliftDrfPojoWriter} for this model
 */
@Override
protected SharedTreePojoWriter makeTreePojoWriter() {
    final CompressedForest.LocalCompressedForest fetchedForest =
            new CompressedForest(_output._treeKeys, _output._domains).fetch();
    return new UpliftDrfPojoWriter(this, fetchedForest._trees);
}
}
23 changes: 23 additions & 0 deletions h2o-algos/src/main/java/hex/tree/uplift/UpliftDrfMojoWriter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package hex.tree.uplift;

import hex.tree.SharedTreeMojoWriter;

import java.io.IOException;

/**
 * MOJO writer for {@link UpliftDRFModel}. On top of what the shared tree writer stores,
 * it writes the model's default AUUC thresholds under the key
 * {@code default_auuc_thresholds}.
 */
public class UpliftDrfMojoWriter
        extends SharedTreeMojoWriter<UpliftDRFModel, UpliftDRFModel.UpliftDRFParameters, UpliftDRFModel.UpliftDRFOutput> {

    /** No-argument constructor; it has no direct callers in code. */
    @SuppressWarnings("unused") // Called through reflection in ModelBuildersHandler
    public UpliftDrfMojoWriter() {
    }

    /**
     * @param model model to be written
     */
    public UpliftDrfMojoWriter(UpliftDRFModel model) {
        super(model);
    }

    /**
     * @return the MOJO version string written by this writer
     */
    @Override
    public String mojoVersion() {
        return "1.40";
    }

    /**
     * Writes the superclass model data first, then appends the default AUUC thresholds.
     *
     * @throws IOException declared by the overridden method
     */
    @Override
    protected void writeModelData() throws IOException {
        super.writeModelData();
        writekv("default_auuc_thresholds", model._output._defaultAuucThresholds);
    }
}
28 changes: 28 additions & 0 deletions h2o-algos/src/main/java/hex/tree/uplift/UpliftDrfPojoWriter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package hex.tree.uplift;

import hex.Model;
import hex.genmodel.CategoricalEncoding;
import hex.tree.CompressedTree;
import hex.tree.SharedTreePojoWriter;
import water.util.SBPrintStream;

/**
 * POJO writer for UpliftDRF models. Its only customisation over
 * {@link SharedTreePojoWriter} is the code generated to unify the per-tree predictions.
 */
public class UpliftDrfPojoWriter extends SharedTreePojoWriter {

    /**
     * Creates a writer directly from a trained model.
     *
     * @param model model supplying key, output, encoding, binomial option and tree stats
     * @param trees compressed trees to generate code for
     */
    UpliftDrfPojoWriter(UpliftDRFModel model, CompressedTree[][] trees) {
        super(model._key, model._output, model.getGenModelEncoding(), model.binomialOpt(),
                trees, model._output._treeStats);
    }

    /**
     * Creates a writer for a generic model; no tree stats are passed to the superclass.
     *
     * @param model          model supplying key and output
     * @param encoding       categorical encoding to use
     * @param binomialOpt    forwarded to the superclass
     * @param trees          compressed trees to generate code for
     * @param balanceClasses currently unused; kept so existing callers keep compiling
     */
    UpliftDrfPojoWriter(Model<?, ?, ?> model, CategoricalEncoding encoding,
                        boolean binomialOpt, CompressedTree[][] trees,
                        boolean balanceClasses) {
        super(model._key, model._output, encoding, binomialOpt, trees, null);
    }

    /**
     * Emits the Java statements that post-process the accumulated predictions:
     * {@code preds[1]} and {@code preds[2]} are divided by the number of trees and
     * {@code preds[0]} is set to their difference.
     *
     * @param body stream the generated statements are written to
     */
    @Override
    protected void toJavaUnifyPreds(SBPrintStream body) {
        body.ip("preds[1] /= " + _trees.length + ";").nl();
        body.ip("preds[2] /= " + _trees.length + ";").nl();
        // Terminate the statement and the line like the two above; without the ';'
        // the generated source is not valid Java.
        body.ip("preds[0] = preds[1] - preds[2];").nl();
    }
}
47 changes: 47 additions & 0 deletions h2o-algos/src/test/java/hex/tree/gbm/GBMTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4458,6 +4458,53 @@ public void testResetThreshold() throws Exception {
}
}

/**
 * Trains a single-tree bernoulli GBM with a validation frame, scores the training frame,
 * and asserts that {@code testJavaScoring} accepts those predictions with tolerance 1e-15.
 */
@Test
public void testMojoMetrics() throws Exception {
    try {
        Scope.enter();

        // Training data: two numeric predictors and a 0/1 response.
        Frame trainingFrame = new TestFrameBuilder()
                .withName("data")
                .withColNames("ColA", "ColB", "Response")
                .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM)
                .withDataForCol(0, ard(0, 1, 0, 1, 0, 1, 0))
                .withDataForCol(1, ard(Double.NaN, 1, 2, 3, 4, 5.6, 7))
                .withDataForCol(2, ard(1, 0, 1, 1, 1, 0, 1))
                .build();
        trainingFrame = trainingFrame.toCategoricalCol(2);

        // Validation data with the same layout.
        Frame validationFrame = new TestFrameBuilder()
                .withName("dataVal")
                .withColNames("ColA", "ColB", "Response")
                .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM)
                .withDataForCol(0, ard(0, 1, 1, 1, 0, 0, 1))
                .withDataForCol(1, ard(Double.NaN, 1, 3, 2, 4, 8, 7))
                .withDataForCol(2, ard(1, 1, 1, 0, 0, 1, 1))
                .build();
        validationFrame = validationFrame.toCategoricalCol(2);

        GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
        gbmParams._train = trainingFrame._key;
        gbmParams._valid = validationFrame._key;
        gbmParams._response_column = "Response";
        gbmParams._ntrees = 1;
        gbmParams._min_rows = 0.1;
        gbmParams._distribution = bernoulli;

        GBMModel trainedModel = new GBM(gbmParams).trainModel().get();
        Scope.track_generic(trainedModel);

        Frame trainingPredictions = trainedModel.score(trainingFrame);
        Scope.track_generic(trainingPredictions);

        boolean javaScoringMatches = trainedModel.testJavaScoring(trainingFrame, trainingPredictions, 1e-15);
        assertTrue(javaScoringMatches);
    } finally {
        Scope.exit();
    }
}

@Test
public void testGBMFeatureInteractions() {
Scope.enter();
Expand Down
94 changes: 94 additions & 0 deletions h2o-algos/src/test/java/hex/tree/uplift/UpliftDRFTest.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package hex.tree.uplift;

import hex.ScoreKeeper;
import hex.genmodel.MojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.UpliftBinomialModelPrediction;
import hex.genmodel.utils.ArrayUtils;
import hex.genmodel.utils.DistributionFamily;
import org.junit.Assume;
import org.junit.Test;
import org.junit.runner.RunWith;
import water.H2O;
import water.Scope;
import water.TestUtil;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
Expand All @@ -14,7 +20,14 @@
import water.runner.CloudSize;
import water.runner.H2ORunner;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import static org.junit.Assert.*;

Expand Down Expand Up @@ -356,4 +369,85 @@ public void testPredictCorrectOutput() {
Scope.exit();
}
}

/**
 * Trains a four-tree UpliftDRF model on a small frame, scores the same frame, and
 * asserts that {@code testJavaScoring} accepts the predictions with tolerance 1e-15.
 */
@Test
public void testMojo() {
    try {
        Scope.enter();

        Frame trainingFrame = new TestFrameBuilder()
                .withColNames("C0", "C1", "treatment", "conversion")
                .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT, Vec.T_CAT)
                .withDataForCol(0, ard(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0))
                .withDataForCol(1, ard(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0))
                .withDataForCol(2, ar("T", "C", "T", "T", "T", "C", "C", "C", "C", "C"))
                .withDataForCol(3, ar("Yes", "No", "Yes", "No", "Yes", "No", "Yes", "No", "Yes", "Yes"))
                .build();
        trainingFrame.toCategoricalCol("treatment");
        trainingFrame.toCategoricalCol("conversion");

        UpliftDRFModel.UpliftDRFParameters upliftParams = new UpliftDRFModel.UpliftDRFParameters();
        upliftParams._train = trainingFrame._key;
        upliftParams._response_column = "conversion";
        upliftParams._treatment_column = "treatment";
        upliftParams._ntrees = 4;

        UpliftDRFModel trainedModel = new UpliftDRF(upliftParams).trainModel().get();
        Scope.track_generic(trainedModel);

        Frame predictions = trainedModel.score(trainingFrame);
        Scope.track_generic(predictions);

        boolean javaScoringMatches = trainedModel.testJavaScoring(trainingFrame, predictions, 1e-15);
        assertTrue(javaScoringMatches);
    } finally {
        Scope.exit();
    }
}

/**
 * Trains a four-tree UpliftDRF model, converts it to a MOJO, and predicts every training
 * row through {@link EasyPredictModelWrapper}. For each row it checks that three
 * prediction values are returned and that {@code predictions[0]} equals
 * {@code predictions[1] - predictions[2]} exactly.
 */
@Test
public void testEasyPredictMojo() throws Exception {
    try {
        Scope.enter();
        Frame train = new TestFrameBuilder()
                .withColNames("C0", "C1", "treatment", "conversion")
                .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT, Vec.T_CAT)
                .withDataForCol(0, ard(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0))
                .withDataForCol(1, ard(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0))
                .withDataForCol(2, ar("T", "C", "T", "T", "T", "C", "C", "C", "C", "C"))
                .withDataForCol(3, ar("Yes", "No", "Yes", "No", "Yes", "No", "Yes", "No", "Yes", "Yes"))
                .build();
        train.toCategoricalCol("treatment");
        train.toCategoricalCol("conversion");
        Scope.track_generic(train);

        UpliftDRFModel.UpliftDRFParameters p = new UpliftDRFModel.UpliftDRFParameters();
        p._train = train._key;
        p._response_column = "conversion";
        p._treatment_column = "treatment";
        p._ntrees = 4;

        UpliftDRF udrf = new UpliftDRF(p);
        UpliftDRFModel model = udrf.trainModel().get();
        Scope.track_generic(model);

        MojoModel mojo = model.toMojo();
        EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(
                new EasyPredictModelWrapper.Config()
                        .setModel(mojo)
                        .setEnableContributions(false)
        );

        // Only the columns the MOJO declares as features are fed to the wrapper.
        Frame featureFr = train.subframe(mojo.features());
        Scope.track_generic(featureFr);
        for (int i = 0; i < featureFr.numRows(); i++) {
            RowData row = new RowData();
            for (String feat : featureFr.names()) {
                // Missing values are left out of the row instead of being put as NaN.
                if (!featureFr.vec(feat).isNA(i)) {
                    double value = featureFr.vec(feat).at(i);
                    row.put(feat, value);
                }
            }
            UpliftBinomialModelPrediction pred = wrapper.predictUpliftBinomial(row);
            // JUnit convention is (expected, actual); the swapped order would produce a
            // misleading failure message.
            assertEquals(3, pred.predictions.length);
            assertEquals(pred.predictions[0], pred.predictions[1] - pred.predictions[2], 0);
        }
    } finally {
        Scope.exit();
    }
}
}
Loading

0 comments on commit 8856059

Please sign in to comment.